chenlin committed on
Commit d9dadf3
1 Parent(s): 6c29f2e
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitattributes copy ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
app.py CHANGED
@@ -1,7 +1,126 @@
- def greet(name):
-     return "Hello " + name + "!!"
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
+ import shutil
+ import tempfile
+
+ import spaces
  import gradio as gr
+ import torch
+
+ from llava.conversation import Conversation, conv_templates
+ from llava.serve.gradio_utils import (Chat, block_css, learn_more_markdown,
+                                       title_markdown)
+
+
+ def save_video_to_local(video_path):
+     filename = os.path.join('temp', next(
+         tempfile._get_candidate_names()) + '.mp4')
+     shutil.copyfile(video_path, filename)
+     return filename
+
+
+ @spaces.GPU(duration=60)
+ def generate(video, textbox_in, first_run, state, state_):
+     flag = 1
+     if not textbox_in:
+         if len(state_.messages) > 0:
+             textbox_in = state_.messages[-1][1]
+             state_.messages.pop(-1)
+             flag = 0
+         else:
+             return "Please enter instruction"
+
+     video = video if video else "none"
+
+     if type(state) is not Conversation:
+         state = conv_templates[conv_mode].copy()
+         state_ = conv_templates[conv_mode].copy()
+
+     first_run = False if len(state.messages) > 0 else True
+
+     text_en_out, state_ = handler.generate(
+         video, textbox_in, first_run=first_run, state=state_)
+     state_.messages[-1] = (state_.roles[1], text_en_out)
+
+     textbox_out = text_en_out
+
+     if flag:
+         state.append_message(state.roles[0], textbox_in)
+     state.append_message(state.roles[1], textbox_out)
+     torch.cuda.empty_cache()
+     return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), gr.update(value=video if os.path.exists(video) else None, interactive=True))
+
+
+ def clear_history(state, state_):
+     state = conv_templates[conv_mode].copy()
+     state_ = conv_templates[conv_mode].copy()
+     return (gr.update(value=None, interactive=True),
+             gr.update(value=None, interactive=True),
+             True, state, state_, state.to_gradio_chatbot())
+
+
+ conv_mode = "llava_llama_3"
+ model_path = 'Lin-Chen/sharegpt4video-8b'
+ device = 'cuda'
+ load_8bit = False
+ load_4bit = False
+ dtype = torch.float16
+ handler = Chat(model_path, conv_mode=conv_mode,
+                load_8bit=load_8bit, load_4bit=load_4bit, device=device)
+
+ textbox = gr.Textbox(
+     show_label=False, placeholder="Enter text and press ENTER", container=False
+ )
+ with gr.Blocks(title='ShareGPT4Video-8B🚀', theme=gr.themes.Default(), css=block_css) as demo:
+     gr.Markdown(title_markdown)
+     state = gr.State()
+     state_ = gr.State()
+     first_run = gr.State()
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             video = gr.Video(label="Input Video")
+
+             cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+         with gr.Column(scale=7):
+             chatbot = gr.Chatbot(label="ShareGPT4Video-8B",
+                                  bubble_full_width=True)
+             with gr.Row():
+                 with gr.Column(scale=8):
+                     textbox.render()
+                 with gr.Column(scale=1, min_width=50):
+                     submit_btn = gr.Button(
+                         value="Send", variant="primary", interactive=True
+                     )
+             with gr.Row(elem_id="buttons") as button_row:
+                 regenerate_btn = gr.Button(
+                     value="🔄 Regenerate", interactive=True)
+                 clear_btn = gr.Button(
+                     value="🗑️ Clear history", interactive=True)
+
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 [
+                     f"{cur_dir}/examples/sample_demo_1.mp4",
+                     "Why is this video funny?",
+                 ],
+                 [
+                     f"{cur_dir}/examples/C_1_0.mp4",
+                     "Write a poem for this video.",
+                 ],
+                 [
+                     f"{cur_dir}/examples/yoga.mp4",
+                     "What is happening in this video?",
+                 ]
+             ],
+             inputs=[video, textbox],
+         )
+     gr.Markdown(learn_more_markdown)
+
+     submit_btn.click(generate, [video, textbox, first_run, state, state_],
+                      [state, state_, chatbot, first_run, textbox, video])
+     clear_btn.click(clear_history, [state, state_],
+                     [video, textbox, first_run, state, state_, chatbot])
+
+ demo.launch(server_name='0.0.0.0', server_port=23858, share=True)
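
A note on `save_video_to_local` above: `tempfile._get_candidate_names()` is a private CPython helper with no stability guarantee across Python versions. A minimal equivalent built only on public APIs could look like the sketch below; the `temp` directory name comes from the code above, while the helper name and the `uuid`-based naming are illustrative, not part of this commit.

    import os
    import shutil
    import uuid

    def save_video_to_local_public(video_path):
        # Same behaviour as save_video_to_local above, but the unique filename
        # comes from the public uuid module instead of tempfile._get_candidate_names().
        os.makedirs('temp', exist_ok=True)
        filename = os.path.join('temp', uuid.uuid4().hex + '.mp4')
        shutil.copyfile(video_path, filename)
        return filename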
 
examples/C_1_0.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5957657865ccb3c8101e82a5a10bbc7e9c33eb957bd53989098a7c0e42512c70
+ size 596629
examples/sample_demo_1.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc6562a172eb9cb3c760a3c9992349c1faa2c793c112b7b9e50bd5cb17c2164d
+ size 1549315
examples/yoga.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:74b65d9bec7f83e487b7f923076c01d476dd2ef7ed83928a696ab6f88c7751b7
+ size 776184
llava/__init__.py ADDED
@@ -0,0 +1 @@
+ # from .model import LlavaLlamaForCausalLM
llava/constants.py ADDED
@@ -0,0 +1,13 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
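
These constants are consumed by the tokenizer helpers added later in this commit (see llava/mm_utils.py): the literal `<image>` placeholder in a prompt is replaced by the `IMAGE_TOKEN_INDEX` sentinel so visual features can be spliced in at that position. A small illustration, assuming a Hugging Face `tokenizer` is already loaded:

    from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
    from llava.mm_utils import tokenizer_image_token

    prompt = DEFAULT_IMAGE_TOKEN + "\nWhat is happening in this video?"
    input_ids = tokenizer_image_token(prompt, tokenizer,
                                      image_token_index=IMAGE_TOKEN_INDEX,
                                      return_tensors='pt')
    # input_ids now carries -200 where "<image>" appeared in the prompt.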
llava/conversation.py ADDED
@@ -0,0 +1,428 @@
1
+ import base64
2
+ import dataclasses
3
+ from enum import Enum, auto
4
+ from io import BytesIO
5
+ from typing import Any, List, Tuple
6
+
7
+ from PIL import Image
8
+ from transformers import AutoTokenizer
9
+
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+ SINGLE = auto()
14
+ TWO = auto()
15
+ MPT = auto()
16
+ PLAIN = auto()
17
+ LLAMA_2 = auto()
18
+ LLAMA_3 = auto()
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class Conversation:
23
+ """A class that keeps all conversation history."""
24
+ system: str
25
+ roles: List[str]
26
+ messages: List[List[str]]
27
+ offset: int
28
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
29
+ sep: str = "###"
30
+ sep2: str = None
31
+ version: str = "Unknown"
32
+
33
+ skip_next: bool = False
34
+
35
+ def get_prompt(self):
36
+ messages = self.messages
37
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
38
+ messages = self.messages.copy()
39
+ init_role, init_msg = messages[0].copy()
40
+ init_msg = init_msg[0].replace("<image>", "").strip()
41
+ if 'mmtag' in self.version:
42
+ messages[0] = (init_role, init_msg)
43
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
44
+ messages.insert(1, (self.roles[1], "Received."))
45
+ else:
46
+ messages[0] = (init_role, "<image>\n" + init_msg)
47
+
48
+ if self.sep_style == SeparatorStyle.SINGLE:
49
+ ret = self.system + self.sep
50
+ for role, message in messages:
51
+ if message:
52
+ if type(message) is tuple:
53
+ message, _, _ = message
54
+ ret += role + ": " + message + self.sep
55
+ else:
56
+ ret += role + ":"
57
+ elif self.sep_style == SeparatorStyle.TWO:
58
+ seps = [self.sep, self.sep2]
59
+ ret = self.system + seps[0]
60
+ for i, (role, message) in enumerate(messages):
61
+ if message:
62
+ if type(message) is tuple:
63
+ message, _, _ = message
64
+ ret += role + ": " + message + seps[i % 2]
65
+ else:
66
+ ret += role + ":"
67
+ elif self.sep_style == SeparatorStyle.MPT:
68
+ ret = self.system + self.sep
69
+ for role, message in messages:
70
+ if message:
71
+ if type(message) is tuple:
72
+ message, _, _ = message
73
+ ret += role + message + self.sep
74
+ else:
75
+ ret += role
76
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
77
+ def wrap_sys(
78
+ msg): return f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
79
+
80
+ def wrap_inst(msg): return f"[INST] {msg} [/INST]"
81
+ ret = ""
82
+
83
+ for i, (role, message) in enumerate(messages):
84
+ if i == 0:
85
+ assert message, "first message should not be none"
86
+ assert role == self.roles[0], "first message should come from user"
87
+ if message:
88
+ if type(message) is tuple:
89
+ message, _, _ = message
90
+ if i == 0:
91
+ message = wrap_sys(self.system) + message
92
+ if i % 2 == 0:
93
+ message = wrap_inst(message)
94
+ ret += self.sep + message
95
+ else:
96
+ ret += " " + message + " " + self.sep2
97
+ else:
98
+ ret += ""
99
+ ret = ret.lstrip(self.sep)
100
+ elif self.sep_style == SeparatorStyle.PLAIN:
101
+ seps = [self.sep, self.sep2]
102
+ ret = self.system
103
+ for i, (role, message) in enumerate(messages):
104
+ if message:
105
+ if type(message) is tuple:
106
+ message, _, _ = message
107
+ ret += message + seps[i % 2]
108
+ else:
109
+ ret += ""
110
+ else:
111
+ raise ValueError(f"Invalid style: {self.sep_style}")
112
+
113
+ return ret
114
+
115
+ def append_message(self, role, message):
116
+ self.messages.append([role, message])
117
+
118
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
119
+ if image_process_mode == "Pad":
120
+ def expand2square(pil_img, background_color=(122, 116, 104)):
121
+ width, height = pil_img.size
122
+ if width == height:
123
+ return pil_img
124
+ elif width > height:
125
+ result = Image.new(
126
+ pil_img.mode, (width, width), background_color)
127
+ result.paste(pil_img, (0, (width - height) // 2))
128
+ return result
129
+ else:
130
+ result = Image.new(
131
+ pil_img.mode, (height, height), background_color)
132
+ result.paste(pil_img, ((height - width) // 2, 0))
133
+ return result
134
+ image = expand2square(image)
135
+ elif image_process_mode in ["Default", "Crop"]:
136
+ pass
137
+ elif image_process_mode == "Resize":
138
+ image = image.resize((336, 336))
139
+ else:
140
+ raise ValueError(
141
+ f"Invalid image_process_mode: {image_process_mode}")
142
+ if max(image.size) > max_len:
143
+ max_hw, min_hw = max(image.size), min(image.size)
144
+ aspect_ratio = max_hw / min_hw
145
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
146
+ longest_edge = int(shortest_edge * aspect_ratio)
147
+ W, H = image.size
148
+ if H > W:
149
+ H, W = longest_edge, shortest_edge
150
+ else:
151
+ H, W = shortest_edge, longest_edge
152
+ image = image.resize((W, H))
153
+ if return_pil:
154
+ return image
155
+ else:
156
+ buffered = BytesIO()
157
+ image.save(buffered, format=image_format)
158
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
159
+ return img_b64_str
160
+
161
+ def get_images(self, return_pil=False):
162
+ images = []
163
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
164
+ if i % 2 == 0:
165
+ if type(msg) is tuple:
166
+ msg, image, image_process_mode = msg
167
+ image = self.process_image(
168
+ image, image_process_mode, return_pil=return_pil)
169
+ images.append(image)
170
+ return images
171
+
172
+ def to_gradio_chatbot(self):
173
+ ret = []
174
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
175
+ if i % 2 == 0:
176
+ if type(msg) is tuple:
177
+ msg, image, image_process_mode = msg
178
+ img_b64_str = self.process_image(
179
+ image, "Default", return_pil=False,
180
+ image_format='JPEG')
181
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
182
+ msg = img_str + msg.replace('<image>', '').strip()
183
+ ret.append([msg, None])
184
+ else:
185
+ ret.append([msg, None])
186
+ else:
187
+ ret[-1][-1] = msg
188
+ return ret
189
+
190
+ def copy(self):
191
+ return Conversation(
192
+ system=self.system,
193
+ roles=self.roles,
194
+ messages=[[x, y] for x, y in self.messages],
195
+ offset=self.offset,
196
+ sep_style=self.sep_style,
197
+ sep=self.sep,
198
+ sep2=self.sep2,
199
+ version=self.version)
200
+
201
+ def dict(self):
202
+ if len(self.get_images()) > 0:
203
+ return {
204
+ "system": self.system,
205
+ "roles": self.roles,
206
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
207
+ "offset": self.offset,
208
+ "sep": self.sep,
209
+ "sep2": self.sep2,
210
+ }
211
+ return {
212
+ "system": self.system,
213
+ "roles": self.roles,
214
+ "messages": self.messages,
215
+ "offset": self.offset,
216
+ "sep": self.sep,
217
+ "sep2": self.sep2,
218
+ }
219
+
220
+
221
+ conv_vicuna_v0 = Conversation(
222
+ system="A chat between a curious human and an artificial intelligence assistant. "
223
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
224
+ roles=("Human", "Assistant"),
225
+ messages=(
226
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
227
+ ("Assistant",
228
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
229
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
230
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
231
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
232
+ "renewable and non-renewable energy sources:\n"
233
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
234
+ "energy sources are finite and will eventually run out.\n"
235
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
236
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
237
+ "and other negative effects.\n"
238
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
239
+ "have lower operational costs than non-renewable sources.\n"
240
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
241
+ "locations than non-renewable sources.\n"
242
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
243
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
244
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
245
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
246
+ ),
247
+ offset=2,
248
+ sep_style=SeparatorStyle.SINGLE,
249
+ sep="###",
250
+ )
251
+
252
+ conv_vicuna_v1 = Conversation(
253
+ system="A chat between a curious user and an artificial intelligence assistant. "
254
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
255
+ roles=("USER", "ASSISTANT"),
256
+ version="v1",
257
+ messages=(),
258
+ offset=0,
259
+ sep_style=SeparatorStyle.TWO,
260
+ sep=" ",
261
+ sep2="</s>",
262
+ )
263
+
264
+ conv_llama_2 = Conversation(
265
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
266
+
267
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
268
+ roles=("USER", "ASSISTANT"),
269
+ version="llama_v2",
270
+ messages=(),
271
+ offset=0,
272
+ sep_style=SeparatorStyle.LLAMA_2,
273
+ sep="<s>",
274
+ sep2="</s>",
275
+ )
276
+
277
+ conv_llava_llama_2 = Conversation(
278
+ system="You are a helpful language and vision assistant. "
279
+ "You are able to understand the visual content that the user provides, "
280
+ "and assist the user with a variety of tasks using natural language.",
281
+ roles=("USER", "ASSISTANT"),
282
+ version="llama_v2",
283
+ messages=(),
284
+ offset=0,
285
+ sep_style=SeparatorStyle.LLAMA_2,
286
+ sep="<s>",
287
+ sep2="</s>",
288
+ )
289
+
290
+ conv_mpt = Conversation(
291
+ system="""<|im_start|>system
292
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
293
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
294
+ version="mpt",
295
+ messages=(),
296
+ offset=0,
297
+ sep_style=SeparatorStyle.MPT,
298
+ sep="<|im_end|>",
299
+ )
300
+
301
+ conv_llava_plain = Conversation(
302
+ system="",
303
+ roles=("", ""),
304
+ messages=(
305
+ ),
306
+ offset=0,
307
+ sep_style=SeparatorStyle.PLAIN,
308
+ sep="\n",
309
+ )
310
+
311
+ conv_llava_v0 = Conversation(
312
+ system="A chat between a curious human and an artificial intelligence assistant. "
313
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
314
+ roles=("Human", "Assistant"),
315
+ messages=(
316
+ ),
317
+ offset=0,
318
+ sep_style=SeparatorStyle.SINGLE,
319
+ sep="###",
320
+ )
321
+
322
+ conv_llava_v0_mmtag = Conversation(
323
+ system="A chat between a curious user and an artificial intelligence assistant. "
324
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
325
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
326
+ roles=("Human", "Assistant"),
327
+ messages=(
328
+ ),
329
+ offset=0,
330
+ sep_style=SeparatorStyle.SINGLE,
331
+ sep="###",
332
+ version="v0_mmtag",
333
+ )
334
+
335
+ conv_llava_v1 = Conversation(
336
+ system="A chat between a curious human and an artificial intelligence assistant. "
337
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
338
+ roles=("USER", "ASSISTANT"),
339
+ version="v1",
340
+ messages=(),
341
+ offset=0,
342
+ sep_style=SeparatorStyle.TWO,
343
+ sep=" ",
344
+ sep2="</s>",
345
+ )
346
+
347
+ conv_llava_v1_mmtag = Conversation(
348
+ system="A chat between a curious user and an artificial intelligence assistant. "
349
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
350
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
351
+ roles=("USER", "ASSISTANT"),
352
+ messages=(),
353
+ offset=0,
354
+ sep_style=SeparatorStyle.TWO,
355
+ sep=" ",
356
+ sep2="</s>",
357
+ version="v1_mmtag",
358
+ )
359
+
360
+ conv_mistral_instruct = Conversation(
361
+ system="",
362
+ roles=("USER", "ASSISTANT"),
363
+ version="llama_v2",
364
+ messages=(),
365
+ offset=0,
366
+ sep_style=SeparatorStyle.LLAMA_2,
367
+ sep="",
368
+ sep2="</s>",
369
+ )
370
+
371
+ conv_chatml_direct = Conversation(
372
+ system="""<|im_start|>system
373
+ Answer the questions.""",
374
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
375
+ version="mpt",
376
+ messages=(),
377
+ offset=0,
378
+ sep_style=SeparatorStyle.MPT,
379
+ sep="<|im_end|>",
380
+ )
381
+
382
+ conv_yi = Conversation(
383
+ system="""<|im_start|>system\nAnswer the questions.""",
384
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
385
+ version="yi",
386
+ messages=(),
387
+ offset=0,
388
+ sep_style=SeparatorStyle.MPT,
389
+ sep="<|im_end|>\n",
390
+ )
391
+
392
+ conv_llava_llama_3 = Conversation(
393
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
394
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n",
395
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"),
396
+ version="llama3",
397
+ messages=[],
398
+ offset=0,
399
+ sep_style=SeparatorStyle.MPT,
400
+ sep="<|eot_id|>",
401
+ )
402
+
403
+ default_conversation = conv_vicuna_v1
404
+ conv_templates = {
405
+ "default": conv_vicuna_v0,
406
+ "v0": conv_vicuna_v0,
407
+ "v1": conv_vicuna_v1,
408
+ "vicuna_v1": conv_vicuna_v1,
409
+ "llama_2": conv_llama_2,
410
+ "mistral_instruct": conv_mistral_instruct,
411
+ "chatml_direct": conv_chatml_direct,
412
+ "mistral_direct": conv_chatml_direct,
413
+
414
+ "plain": conv_llava_plain,
415
+ "v0_plain": conv_llava_plain,
416
+ "llava_v0": conv_llava_v0,
417
+ "v0_mmtag": conv_llava_v0_mmtag,
418
+ "llava_v1": conv_llava_v1,
419
+ "v1_mmtag": conv_llava_v1_mmtag,
420
+ "llava_llama_2": conv_llava_llama_2,
421
+ "llava_llama_3": conv_llava_llama_3,
422
+
423
+ "mpt": conv_mpt,
424
+ }
425
+
426
+
427
+ if __name__ == "__main__":
428
+ print(default_conversation.get_prompt())
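
For context, the templates registered above are used by copying one, appending alternating user/assistant turns, and rendering the prompt string. A minimal sketch with the `llava_llama_3` template (the question text is illustrative):

    from llava.conversation import conv_templates

    conv = conv_templates["llava_llama_3"].copy()
    conv.append_message(conv.roles[0], "<image>\nWhy is this video funny?")
    conv.append_message(conv.roles[1], None)  # placeholder for the model's reply
    prompt = conv.get_prompt()
    print(prompt)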
llava/mm_utils.py ADDED
@@ -0,0 +1,250 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+
8
+ from transformers import StoppingCriteria
9
+ from llava.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def select_best_resolution(original_size, possible_resolutions):
13
+ """
14
+ Selects the best resolution from a list of possible resolutions based on the original size.
15
+
16
+ Args:
17
+ original_size (tuple): The original size of the image in the format (width, height).
18
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19
+
20
+ Returns:
21
+ tuple: The best fit resolution in the format (width, height).
22
+ """
23
+ original_width, original_height = original_size
24
+ best_fit = None
25
+ max_effective_resolution = 0
26
+ min_wasted_resolution = float('inf')
27
+
28
+ for width, height in possible_resolutions:
29
+ scale = min(width / original_width, height / original_height)
30
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32
+ wasted_resolution = (width * height) - effective_resolution
33
+
34
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35
+ max_effective_resolution = effective_resolution
36
+ min_wasted_resolution = wasted_resolution
37
+ best_fit = (width, height)
38
+
39
+ return best_fit
40
+
41
+
42
+ def resize_and_pad_image(image, target_resolution):
43
+ """
44
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
45
+
46
+ Args:
47
+ image (PIL.Image.Image): The input image.
48
+ target_resolution (tuple): The target resolution (width, height) of the image.
49
+
50
+ Returns:
51
+ PIL.Image.Image: The resized and padded image.
52
+ """
53
+ original_width, original_height = image.size
54
+ target_width, target_height = target_resolution
55
+
56
+ scale_w = target_width / original_width
57
+ scale_h = target_height / original_height
58
+
59
+ if scale_w < scale_h:
60
+ new_width = target_width
61
+ new_height = min(math.ceil(original_height * scale_w), target_height)
62
+ else:
63
+ new_height = target_height
64
+ new_width = min(math.ceil(original_width * scale_h), target_width)
65
+
66
+ # Resize the image
67
+ resized_image = image.resize((new_width, new_height))
68
+
69
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
+ paste_x = (target_width - new_width) // 2
71
+ paste_y = (target_height - new_height) // 2
72
+ new_image.paste(resized_image, (paste_x, paste_y))
73
+
74
+ return new_image
75
+
76
+
77
+ def divide_to_patches(image, patch_size):
78
+ """
79
+ Divides an image into patches of a specified size.
80
+
81
+ Args:
82
+ image (PIL.Image.Image): The input image.
83
+ patch_size (int): The size of each patch.
84
+
85
+ Returns:
86
+ list: A list of PIL.Image.Image objects representing the patches.
87
+ """
88
+ patches = []
89
+ width, height = image.size
90
+ for i in range(0, height, patch_size):
91
+ for j in range(0, width, patch_size):
92
+ box = (j, i, j + patch_size, i + patch_size)
93
+ patch = image.crop(box)
94
+ patches.append(patch)
95
+
96
+ return patches
97
+
98
+
99
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100
+ """
101
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102
+
103
+ Args:
104
+ image_size (tuple): The size of the input image in the format (width, height).
105
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
106
+ patch_size (int): The size of each image patch.
107
+
108
+ Returns:
109
+ tuple: The shape of the image patch grid in the format (width, height).
110
+ """
111
+ if type(grid_pinpoints) is list:
112
+ possible_resolutions = grid_pinpoints
113
+ else:
114
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
115
+ width, height = select_best_resolution(image_size, possible_resolutions)
116
+ return width // patch_size, height // patch_size
117
+
118
+
119
+ def process_anyres_image(image, processor, grid_pinpoints):
120
+ """
121
+ Process an image with variable resolutions.
122
+
123
+ Args:
124
+ image (PIL.Image.Image): The input image to be processed.
125
+ processor: The image processor object.
126
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
127
+
128
+ Returns:
129
+ torch.Tensor: A tensor containing the processed image patches.
130
+ """
131
+ if type(grid_pinpoints) is list:
132
+ possible_resolutions = grid_pinpoints
133
+ else:
134
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
135
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
136
+ image_padded = resize_and_pad_image(image, best_resolution)
137
+
138
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
139
+
140
+ shortest_edge = processor.size['shortest_edge'] if isinstance(
141
+ processor.size, dict) else min(processor.size[0], processor.size[1])
142
+ image_original_resize = image.resize(
143
+ (shortest_edge, shortest_edge))
144
+
145
+ image_patches = [image_original_resize] + patches
146
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
147
+ for image_patch in image_patches]
148
+ return torch.stack(image_patches, dim=0)
149
+
150
+
151
+ def load_image_from_base64(image):
152
+ return Image.open(BytesIO(base64.b64decode(image)))
153
+
154
+
155
+ def expand2square(pil_img, background_color):
156
+ width, height = pil_img.size
157
+ if width == height:
158
+ return pil_img
159
+ elif width > height:
160
+ result = Image.new(pil_img.mode, (width, width), background_color)
161
+ result.paste(pil_img, (0, (width - height) // 2))
162
+ return result
163
+ else:
164
+ result = Image.new(pil_img.mode, (height, height), background_color)
165
+ result.paste(pil_img, ((height - width) // 2, 0))
166
+ return result
167
+
168
+
169
+ def process_images(images, image_processor, model_cfg):
170
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
171
+ new_images = []
172
+ if image_aspect_ratio == 'pad':
173
+ for image in images:
174
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
175
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
176
+ new_images.append(image)
177
+ elif image_aspect_ratio == "anyres":
178
+ for image in images:
179
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
180
+ new_images.append(image)
181
+ else:
182
+ return image_processor(images, return_tensors='pt')['pixel_values']
183
+ if all(x.shape == new_images[0].shape for x in new_images):
184
+ new_images = torch.stack(new_images, dim=0)
185
+ return new_images
186
+
187
+
188
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
189
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
190
+
191
+ def insert_separator(X, sep):
192
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
193
+
194
+ input_ids = []
195
+ offset = 0
196
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
197
+ offset = 1
198
+ input_ids.append(prompt_chunks[0][0])
199
+
200
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
201
+ input_ids.extend(x[offset:])
202
+
203
+ if return_tensors is not None:
204
+ if return_tensors == 'pt':
205
+ return torch.tensor(input_ids, dtype=torch.long)
206
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
207
+ return input_ids
208
+
209
+
210
+ def get_model_name_from_path(model_path):
211
+ model_path = model_path.strip("/")
212
+ model_paths = model_path.split("/")
213
+ if model_paths[-1].startswith('checkpoint-'):
214
+ return model_paths[-2] + "_" + model_paths[-1]
215
+ else:
216
+ return model_paths[-1]
217
+
218
+ class KeywordsStoppingCriteria(StoppingCriteria):
219
+ def __init__(self, keywords, tokenizer, input_ids):
220
+ self.keywords = keywords
221
+ self.keyword_ids = []
222
+ self.max_keyword_len = 0
223
+ for keyword in keywords:
224
+ cur_keyword_ids = tokenizer(keyword).input_ids
225
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
226
+ cur_keyword_ids = cur_keyword_ids[1:]
227
+ if len(cur_keyword_ids) > self.max_keyword_len:
228
+ self.max_keyword_len = len(cur_keyword_ids)
229
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
230
+ self.tokenizer = tokenizer
231
+ self.start_len = input_ids.shape[1]
232
+
233
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
234
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
235
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
236
+ for keyword_id in self.keyword_ids:
237
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
238
+ if torch.equal(truncated_output_ids, keyword_id):
239
+ return True
240
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
241
+ for keyword in self.keywords:
242
+ if keyword in outputs:
243
+ return True
244
+ return False
245
+
246
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
247
+ outputs = []
248
+ for i in range(output_ids.shape[0]):
249
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
250
+ return all(outputs)
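
As a usage sketch (not part of this commit), the helpers above are typically combined at inference time: `tokenizer_image_token` builds the multimodal input ids and `KeywordsStoppingCriteria` halts generation on a stop string. Here `model`, `tokenizer`, `image_tensor`, `prompt`, and the `<|eot_id|>` stop string are assumptions:

    import torch
    from transformers import StoppingCriteriaList
    from llava.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token

    # Assumed to exist: model, tokenizer, image_tensor, and a prompt containing "<image>".
    input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(model.device)
    stopping = KeywordsStoppingCriteria(["<|eot_id|>"], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=512,
            stopping_criteria=StoppingCriteriaList([stopping]),
        )
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])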
llava/model/__init__.py ADDED
@@ -0,0 +1,6 @@
+ try:
+     from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
+     from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
+     from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
+ except:  # optional backends; skip silently if their dependencies are missing
+     pass
llava/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava import LlavaLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
llava/model/builder.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import shutil
18
+ import warnings
19
+
20
+ import torch
21
+ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
22
+ BitsAndBytesConfig)
23
+
24
+ from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
25
+ DEFAULT_IMAGE_PATCH_TOKEN)
26
+ from llava.model import *
27
+ from llava.train.train import smart_tokenizer_and_embedding_resize
28
+
29
+
30
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, lora_alpha=None, **kwargs):
31
+ kwargs = {"device_map": device_map, **kwargs}
32
+
33
+ if device != "cuda":
34
+ kwargs['device_map'] = {"": device}
35
+
36
+ if load_8bit:
37
+ kwargs['load_in_8bit'] = True
38
+ elif load_4bit:
39
+ kwargs['load_in_4bit'] = True
40
+ kwargs['quantization_config'] = BitsAndBytesConfig(
41
+ load_in_4bit=True,
42
+ bnb_4bit_compute_dtype=torch.float16,
43
+ bnb_4bit_use_double_quant=True,
44
+ bnb_4bit_quant_type='nf4'
45
+ )
46
+ else:
47
+ kwargs['torch_dtype'] = torch.float16
48
+
49
+ if use_flash_attn:
50
+ kwargs['attn_implementation'] = 'flash_attention_2'
51
+
52
+ if 'llava' or 'sharegpt4video' in model_name.lower():
53
+ # Load LLaVA model
54
+ if 'lora' in model_name.lower() and model_base is None:
55
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
56
+ if 'lora' in model_name.lower() and model_base is not None:
57
+ from llava.model.language_model.llava_llama import LlavaConfig
58
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
59
+ tokenizer = AutoTokenizer.from_pretrained(
60
+ model_base, use_fast=False, model_max_length=lora_cfg_pretrained.tokenizer_model_max_length)
61
+ print('Loading LLaVA from base model...')
62
+ model = LlavaLlamaForCausalLM.from_pretrained(
63
+ model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
64
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
65
+ if model.lm_head.weight.shape[0] != token_num:
66
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(
67
+ token_num, tokem_dim, device=model.device, dtype=model.dtype))
68
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(
69
+ token_num, tokem_dim, device=model.device, dtype=model.dtype))
70
+ print('Loading additional LLaVA weights...')
71
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
72
+ non_lora_trainables = torch.load(os.path.join(
73
+ model_path, 'non_lora_trainables.bin'), map_location='cpu')
74
+ else:
75
+ # this is probably from HF Hub
76
+ from huggingface_hub import hf_hub_download
77
+
78
+ def load_from_hf(repo_id, filename, subfolder=None):
79
+ cache_file = hf_hub_download(
80
+ repo_id=repo_id,
81
+ filename=filename,
82
+ subfolder=subfolder)
83
+ return torch.load(cache_file, map_location='cpu')
84
+ non_lora_trainables = load_from_hf(
85
+ model_path, 'non_lora_trainables.bin')
86
+ non_lora_trainables = {(k[11:] if k.startswith(
87
+ 'base_model.') else k): v for k, v in non_lora_trainables.items()}
88
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
89
+ non_lora_trainables = {(k[6:] if k.startswith(
90
+ 'model.') else k): v for k, v in non_lora_trainables.items()}
91
+ model.load_state_dict(non_lora_trainables, strict=False)
92
+
93
+ from peft import PeftModel
94
+ print('Loading LoRA weights...')
95
+ if lora_alpha is not None:
96
+ print("Lora Scaling:", lora_alpha/128)
97
+ model = PeftModel.from_pretrained(
98
+ model, model_path, lora_alpha=lora_alpha, torch_device='cpu')
99
+ else:
100
+ model = PeftModel.from_pretrained(model, model_path, torch_device='cpu')
101
+ print('Merging LoRA weights...')
102
+ model = model.merge_and_unload()
103
+ print('Model is loaded...')
104
+ elif model_base is not None:
105
+ # this may be mm projector only
106
+ print('Loading LLaVA from base model...')
107
+ if 'mpt' in model_name.lower():
108
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
109
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(
110
+ model_path, 'configuration_mpt.py'))
111
+ tokenizer = AutoTokenizer.from_pretrained(
112
+ model_base, use_fast=True)
113
+ cfg_pretrained = AutoConfig.from_pretrained(
114
+ model_path, trust_remote_code=True)
115
+ model = LlavaMptForCausalLM.from_pretrained(
116
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
117
+ else:
118
+ tokenizer = AutoTokenizer.from_pretrained(
119
+ model_base, use_fast=False)
120
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
121
+ model = LlavaLlamaForCausalLM.from_pretrained(
122
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
123
+
124
+ mm_projector_weights = torch.load(os.path.join(
125
+ model_path, 'mm_projector.bin'), map_location='cpu')
126
+ mm_projector_weights = {k: v.to(torch.float16)
127
+ for k, v in mm_projector_weights.items()}
128
+ model.load_state_dict(mm_projector_weights, strict=False)
129
+ else:
130
+ if 'mpt' in model_name.lower():
131
+ tokenizer = AutoTokenizer.from_pretrained(
132
+ model_path, use_fast=True)
133
+ model = LlavaMptForCausalLM.from_pretrained(
134
+ model_path, low_cpu_mem_usage=True, **kwargs)
135
+ elif 'mistral' in model_name.lower():
136
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
137
+ model = LlavaMistralForCausalLM.from_pretrained(
138
+ model_path,
139
+ low_cpu_mem_usage=True,
140
+ **kwargs
141
+ )
142
+ else:
143
+ tokenizer = AutoTokenizer.from_pretrained(
144
+ model_path, use_fast=False)
145
+ model = LlavaLlamaForCausalLM.from_pretrained(
146
+ model_path,
147
+ low_cpu_mem_usage=True,
148
+ **kwargs
149
+ )
150
+ else:
151
+ # Load language model
152
+ if model_base is not None:
153
+ # PEFT model
154
+ from peft import PeftModel
155
+ tokenizer = AutoTokenizer.from_pretrained(
156
+ model_base, use_fast=False)
157
+ model = AutoModelForCausalLM.from_pretrained(
158
+ model_base, low_cpu_mem_usage=True, **kwargs)
159
+ print(f"Loading LoRA weights from {model_path}")
160
+ model = PeftModel.from_pretrained(model, model_path)
161
+ print(f"Merging weights")
162
+ model = model.merge_and_unload()
163
+ print('Convert to FP16...')
164
+ model.to(torch.float16)
165
+ else:
166
+ use_fast = False
167
+ if 'mpt' in model_name.lower():
168
+ tokenizer = AutoTokenizer.from_pretrained(
169
+ model_path, use_fast=True)
170
+ model = AutoModelForCausalLM.from_pretrained(
171
+ model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
172
+ else:
173
+ tokenizer = AutoTokenizer.from_pretrained(
174
+ model_path, use_fast=False)
175
+ model = AutoModelForCausalLM.from_pretrained(
176
+ model_path, low_cpu_mem_usage=True, **kwargs)
177
+
178
+ image_processor = None
179
+
180
+ if 'llava' or 'sharegpt4video' in model_name.lower():
181
+ mm_use_im_start_end = getattr(
182
+ model.config, "mm_use_im_start_end", False)
183
+ mm_use_im_patch_token = getattr(
184
+ model.config, "mm_use_im_patch_token", True)
185
+ if mm_use_im_patch_token:
186
+ tokenizer.add_tokens(
187
+ [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
188
+ if mm_use_im_start_end:
189
+ tokenizer.add_tokens(
190
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
191
+
192
+ vision_tower = model.get_vision_tower()
193
+ if not vision_tower.is_loaded:
194
+ vision_tower.load_model(device_map=device_map)
195
+ if device_map != 'auto':
196
+ vision_tower.to(device=device_map, dtype=torch.float16)
197
+ image_processor = vision_tower.image_processor
198
+
199
+ if hasattr(model.config, "max_sequence_length"):
200
+ context_len = model.config.max_sequence_length
201
+ else:
202
+ context_len = 2048
203
+
204
+ return tokenizer, model, image_processor, context_len
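
A hypothetical call into the loader above; the checkpoint id is the one referenced by app.py, and the remaining arguments mirror the function's defaults:

    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model

    model_path = "Lin-Chen/sharegpt4video-8b"
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=model_path,
        model_base=None,
        model_name=get_model_name_from_path(model_path),
        load_8bit=False,
        load_4bit=False,
        device="cuda",
    )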
llava/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from llava.model import *
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
llava/model/language_model/llava_llama.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, \
22
+ LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+
29
+
30
+ class LlavaConfig(LlamaConfig):
31
+ model_type = "llava_llama"
32
+
33
+
34
+ class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35
+ config_class = LlavaConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(LlavaLlamaModel, self).__init__(config)
39
+
40
+
41
+ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42
+ config_class = LlavaConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = LlavaLlamaModel(config)
47
+ self.pretraining_tp = config.pretraining_tp
48
+ self.vocab_size = config.vocab_size
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ **kwargs
72
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
73
+
74
+ if inputs_embeds is None:
75
+ (
76
+ input_ids,
77
+ position_ids,
78
+ attention_mask,
79
+ past_key_values,
80
+ inputs_embeds,
81
+ labels
82
+ ) = self.prepare_inputs_labels_for_multimodal(
83
+ input_ids,
84
+ position_ids,
85
+ attention_mask,
86
+ past_key_values,
87
+ labels,
88
+ images,
89
+ image_sizes
90
+ )
91
+
92
+ return super().forward(
93
+ input_ids=input_ids,
94
+ attention_mask=attention_mask,
95
+ position_ids=position_ids,
96
+ past_key_values=past_key_values,
97
+ inputs_embeds=inputs_embeds,
98
+ labels=labels,
99
+ use_cache=use_cache,
100
+ output_attentions=output_attentions,
101
+ output_hidden_states=output_hidden_states,
102
+ return_dict=return_dict
103
+ )
104
+
105
+ @torch.no_grad()
106
+ def generate(
107
+ self,
108
+ inputs: Optional[torch.Tensor] = None,
109
+ images: Optional[torch.Tensor] = None,
110
+ image_sizes: Optional[torch.Tensor] = None,
111
+ **kwargs,
112
+ ) -> Union[GenerateOutput, torch.LongTensor]:
113
+ position_ids = kwargs.pop("position_ids", None)
114
+ attention_mask = kwargs.pop("attention_mask", None)
115
+ if "inputs_embeds" in kwargs:
116
+ raise NotImplementedError("`inputs_embeds` is not supported")
117
+
118
+ if images is not None:
119
+ (
120
+ inputs,
121
+ position_ids,
122
+ attention_mask,
123
+ _,
124
+ inputs_embeds,
125
+ _
126
+ ) = self.prepare_inputs_labels_for_multimodal(
127
+ inputs,
128
+ position_ids,
129
+ attention_mask,
130
+ None,
131
+ None,
132
+ images,
133
+ image_sizes=image_sizes
134
+ )
135
+ else:
136
+ inputs_embeds = self.get_model().embed_tokens(inputs)
137
+
138
+ return super().generate(
139
+ position_ids=position_ids,
140
+ attention_mask=attention_mask,
141
+ inputs_embeds=inputs_embeds,
142
+ **kwargs
143
+ )
144
+
145
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
146
+ inputs_embeds=None, **kwargs):
147
+ images = kwargs.pop("images", None)
148
+ image_sizes = kwargs.pop("image_sizes", None)
149
+ inputs = super().prepare_inputs_for_generation(
150
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
151
+ )
152
+ if images is not None:
153
+ inputs['images'] = images
154
+ if image_sizes is not None:
155
+ inputs['image_sizes'] = image_sizes
156
+ return inputs
157
+
158
+ AutoConfig.register("llava_llama", LlavaConfig)
159
+ AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
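
Because the final two lines register the config and model classes with the Auto classes, a checkpoint whose config declares `model_type: "llava_llama"` can also be loaded through the generic Auto API once the `llava` package has been imported. A sketch under that assumption (the checkpoint id is the one used in app.py):

    import torch
    import llava.model  # importing this package runs the registration above
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "Lin-Chen/sharegpt4video-8b",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )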
llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ MistralConfig, MistralModel, MistralForCausalLM
24
+
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+ from transformers.generation.utils import GenerateOutput
27
+
28
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29
+
30
+
31
+ class LlavaMistralConfig(MistralConfig):
32
+ model_type = "llava_mistral"
33
+
34
+
35
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
36
+ config_class = LlavaMistralConfig
37
+
38
+ def __init__(self, config: MistralConfig):
39
+ super(LlavaMistralModel, self).__init__(config)
40
+
41
+
42
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43
+ config_class = LlavaMistralConfig
44
+
45
+ def __init__(self, config):
46
+ super(MistralForCausalLM, self).__init__(config)
47
+ self.model = LlavaMistralModel(config)
48
+
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ AutoConfig.register("llava_mistral", LlavaMistralConfig)
158
+ AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
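A minimal usage sketch (assuming the module above has been imported so the registrations have run); the checkpoint path is a hypothetical placeholder, not a real repository:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import llava.model.language_model.llava_mistral  # runs the AutoConfig/AutoModelForCausalLM registrations above

checkpoint = "path/to/llava-mistral-checkpoint"  # hypothetical local path
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# a config.json declaring "model_type": "llava_mistral" now resolves to LlavaMistralForCausalLM
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16)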
llava/model/language_model/llava_mpt.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from transformers import AutoConfig, AutoModelForCausalLM, \
21
+ MptConfig, MptForCausalLM, MptModel
22
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23
+
24
+
25
+ class LlavaMptConfig(MptConfig):
26
+ model_type = "llava_mpt"
27
+
28
+
29
+ class LlavaMptModel(LlavaMetaModel, MptModel):
30
+ config_class = LlavaMptConfig
31
+
32
+ def __init__(self, config: MptConfig):
33
+ config.hidden_size = config.d_model
34
+ super(LlavaMptModel, self).__init__(config)
35
+
36
+ def embed_tokens(self, x):
37
+ return self.wte(x)
38
+
39
+
40
+ class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41
+ config_class = LlavaMptConfig
42
+ supports_gradient_checkpointing = True
43
+
44
+ def __init__(self, config):
45
+ super(MptForCausalLM, self).__init__(config)
46
+
47
+ self.transformer = LlavaMptModel(config)
48
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49
+
50
+ # Initialize weights and apply final processing
51
+ self.post_init()
52
+
53
+ def get_model(self):
54
+ return self.transformer
55
+
56
+ def _set_gradient_checkpointing(self, module, value=False):
57
+ if isinstance(module, LlavaMptModel):
58
+ module.gradient_checkpointing = value
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: Optional[torch.LongTensor] = None,
63
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64
+ attention_mask: Optional[torch.Tensor] = None,
65
+ inputs_embeds: Optional[torch.Tensor] = None,
66
+ labels: Optional[torch.Tensor] = None,
67
+ use_cache: Optional[bool] = None,
68
+ output_attentions: Optional[bool] = None,
69
+ output_hidden_states: Optional[bool] = None,
70
+ return_dict: Optional[bool] = None,
71
+ images=None):
72
+
73
+ input_ids, _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images)
74
+
75
+ return super().forward(
76
+ input_ids,
77
+ past_key_values=past_key_values,
78
+ attention_mask=attention_mask,
79
+ inputs_embeds=inputs_embeds,
80
+ labels=labels,
81
+ use_cache=use_cache,
82
+ output_attentions=output_attentions,
83
+ output_hidden_states=output_hidden_states,
84
+ return_dict=return_dict,
85
+ )
86
+
87
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88
+ images = kwargs.pop("images", None)
89
+ _inputs = super().prepare_inputs_for_generation(
90
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91
+ )
92
+ _inputs['images'] = images
93
+ return _inputs
94
+
95
+
96
+ AutoConfig.register("llava_mpt", LlavaMptConfig)
97
+ AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
llava/model/llava_arch.py ADDED
@@ -0,0 +1,367 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+ from llava.mm_utils import get_anyres_image_grid_shape
27
+
28
+
29
+ class LlavaMetaModel:
30
+
31
+ def __init__(self, config):
32
+ super(LlavaMetaModel, self).__init__(config)
33
+ if hasattr(config, "mm_vision_tower"):
34
+ self.vision_tower = build_vision_tower(config, delay_load=True)
35
+ self.mm_projector = build_vision_projector(config)
36
+
37
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
38
+ self.image_newline = nn.Parameter(
39
+ torch.empty(config.hidden_size, dtype=self.dtype)
40
+ )
41
+
42
+ def get_vision_tower(self):
43
+ vision_tower = getattr(self, 'vision_tower', None)
44
+ if type(vision_tower) is list:
45
+ vision_tower = vision_tower[0]
46
+ return vision_tower
47
+
48
+ def initialize_vision_modules(self, model_args, fsdp=None):
49
+ vision_tower = model_args.vision_tower
50
+ mm_vision_select_layer = model_args.mm_vision_select_layer
51
+ mm_vision_select_feature = model_args.mm_vision_select_feature
52
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
53
+ mm_patch_merge_type = model_args.mm_patch_merge_type
54
+
55
+ self.config.mm_vision_tower = vision_tower
56
+
57
+ if self.get_vision_tower() is None:
58
+ vision_tower = build_vision_tower(model_args)
59
+
60
+ if fsdp is not None and len(fsdp) > 0:
61
+ self.vision_tower = [vision_tower]
62
+ else:
63
+ self.vision_tower = vision_tower
64
+ else:
65
+ if fsdp is not None and len(fsdp) > 0:
66
+ vision_tower = self.vision_tower[0]
67
+ else:
68
+ vision_tower = self.vision_tower
69
+ vision_tower.load_model()
70
+
71
+ self.config.use_mm_proj = True
72
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
73
+ self.config.mm_hidden_size = vision_tower.hidden_size
74
+ self.config.mm_vision_select_layer = mm_vision_select_layer
75
+ self.config.mm_vision_select_feature = mm_vision_select_feature
76
+ self.config.mm_patch_merge_type = mm_patch_merge_type
77
+
78
+ if getattr(self, 'mm_projector', None) is None:
79
+ self.mm_projector = build_vision_projector(self.config)
80
+
81
+ if 'unpad' in mm_patch_merge_type:
82
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
83
+ self.image_newline = nn.Parameter(
84
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
85
+ )
86
+ else:
87
+ # In case it is frozen by LoRA
88
+ for p in self.mm_projector.parameters():
89
+ p.requires_grad = True
90
+
91
+ if pretrain_mm_mlp_adapter is not None:
92
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
93
+ def get_w(weights, keyword):
94
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
95
+
96
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
97
+
98
+
99
+ def unpad_image(tensor, original_size):
100
+ """
101
+ Unpads a PyTorch tensor of a padded and resized image.
102
+
103
+ Args:
104
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
105
+ original_size (tuple): The original size of PIL image (width, height).
106
+
107
+ Returns:
108
+ torch.Tensor: The unpadded image tensor.
109
+ """
110
+ original_width, original_height = original_size
111
+ current_height, current_width = tensor.shape[1:]
112
+
113
+ original_aspect_ratio = original_width / original_height
114
+ current_aspect_ratio = current_width / current_height
115
+
116
+ if original_aspect_ratio > current_aspect_ratio:
117
+ scale_factor = current_width / original_width
118
+ new_height = int(original_height * scale_factor)
119
+ padding = (current_height - new_height) // 2
120
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
121
+ else:
122
+ scale_factor = current_height / original_height
123
+ new_width = int(original_width * scale_factor)
124
+ padding = (current_width - new_width) // 2
125
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
126
+
127
+ return unpadded_tensor
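A small worked example of unpad_image under assumed numbers: a 24x24 feature grid produced from a padded 640x480 image keeps all columns and trims the padded rows.

import torch

feat = torch.zeros(3, 24, 24)            # CxHxW features from the square, padded input
out = unpad_image(feat, (640, 480))      # original (width, height)
# wider-than-square original, so padding was added on height and is trimmed back off
print(out.shape)                         # torch.Size([3, 18, 24])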
128
+
129
+
130
+ class LlavaMetaForCausalLM(ABC):
131
+
132
+ @abstractmethod
133
+ def get_model(self):
134
+ pass
135
+
136
+ def get_vision_tower(self):
137
+ return self.get_model().get_vision_tower()
138
+
139
+ def encode_images(self, images):
140
+ image_features = self.get_model().get_vision_tower()(images)
141
+ image_features = self.get_model().mm_projector(image_features)
142
+ return image_features
143
+
144
+ def prepare_inputs_labels_for_multimodal(
145
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
146
+ images, image_sizes=None
147
+ ):
148
+ vision_tower = self.get_vision_tower()
149
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
150
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
151
+
152
+ if type(images) is list or images.ndim == 5:
153
+ if type(images) is list:
154
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
155
+ concat_images = torch.cat([image for image in images], dim=0)
156
+ image_features = self.encode_images(concat_images)
157
+ split_sizes = [image.shape[0] for image in images]
158
+ image_features = torch.split(image_features, split_sizes, dim=0)
159
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
160
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
161
+ if mm_patch_merge_type == 'flat':
162
+ image_features = [x.flatten(0, 1) for x in image_features]
163
+ elif mm_patch_merge_type.startswith('spatial'):
164
+ new_image_features = []
165
+ for image_idx, image_feature in enumerate(image_features):
166
+ if image_feature.shape[0] > 1:
167
+ base_image_feature = image_feature[0]
168
+ image_feature = image_feature[1:]
169
+ height = width = self.get_vision_tower().num_patches_per_side
170
+ assert height * width == base_image_feature.shape[0]
171
+ if image_aspect_ratio == 'anyres':
172
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
173
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
174
+ else:
175
+ raise NotImplementedError
176
+ if 'unpad' in mm_patch_merge_type:
177
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
178
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
179
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
180
+ image_feature = torch.cat((
181
+ image_feature,
182
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
183
+ ), dim=-1)
184
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
185
+ else:
186
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
187
+ image_feature = image_feature.flatten(0, 3)
188
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
189
+ else:
190
+ image_feature = image_feature[0]
191
+ if 'unpad' in mm_patch_merge_type:
192
+ image_feature = torch.cat((
193
+ image_feature,
194
+ self.model.image_newline[None].to(image_feature.device)
195
+ ), dim=0)
196
+ new_image_features.append(image_feature)
197
+ image_features = new_image_features
198
+ else:
199
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
200
+ else:
201
+ image_features = self.encode_images(images)
202
+
203
+ # TODO: image start / end is not implemented here to support pretraining.
204
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
205
+ raise NotImplementedError
206
+
207
+ # Let's just add dummy tensors if they do not exist,
208
+ # it is a headache to deal with None all the time.
209
+ # But it is not ideal, and if you have a better idea,
210
+ # please open an issue / submit a PR, thanks.
211
+ _labels = labels
212
+ _position_ids = position_ids
213
+ _attention_mask = attention_mask
214
+ if attention_mask is None:
215
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
216
+ else:
217
+ attention_mask = attention_mask.bool()
218
+ if position_ids is None:
219
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
220
+ if labels is None:
221
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
222
+
223
+ # remove the padding using attention_mask -- FIXME
224
+ _input_ids = input_ids
225
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
226
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
227
+
228
+ new_input_embeds = []
229
+ new_labels = []
230
+ cur_image_idx = 0
231
+ for batch_idx, cur_input_ids in enumerate(input_ids):
232
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
233
+ if num_images == 0:
234
+ cur_image_features = image_features[cur_image_idx]
235
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
236
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
237
+ new_input_embeds.append(cur_input_embeds)
238
+ new_labels.append(labels[batch_idx])
239
+ cur_image_idx += 1
240
+ continue
241
+
242
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
243
+ cur_input_ids_noim = []
244
+ cur_labels = labels[batch_idx]
245
+ cur_labels_noim = []
246
+ for i in range(len(image_token_indices) - 1):
247
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
248
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
249
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
250
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
251
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
252
+ cur_new_input_embeds = []
253
+ cur_new_labels = []
254
+
255
+ for i in range(num_images + 1):
256
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
257
+ cur_new_labels.append(cur_labels_noim[i])
258
+ if i < num_images:
259
+ cur_image_features = image_features[cur_image_idx]
260
+ cur_image_idx += 1
261
+ cur_new_input_embeds.append(cur_image_features)
262
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
263
+
264
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
265
+
266
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
267
+ cur_new_labels = torch.cat(cur_new_labels)
268
+
269
+ new_input_embeds.append(cur_new_input_embeds)
270
+ new_labels.append(cur_new_labels)
271
+
272
+ # Truncate sequences to max length as image embeddings can make the sequence longer
273
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
274
+ if tokenizer_model_max_length is not None:
275
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
276
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
277
+
278
+ # Combine them
279
+ max_len = max(x.shape[0] for x in new_input_embeds)
280
+ batch_size = len(new_input_embeds)
281
+
282
+ new_input_embeds_padded = []
283
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
284
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
285
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
286
+
287
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
288
+ cur_len = cur_new_embed.shape[0]
289
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
290
+ new_input_embeds_padded.append(torch.cat((
291
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
292
+ cur_new_embed
293
+ ), dim=0))
294
+ if cur_len > 0:
295
+ new_labels_padded[i, -cur_len:] = cur_new_labels
296
+ attention_mask[i, -cur_len:] = True
297
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
298
+ else:
299
+ new_input_embeds_padded.append(torch.cat((
300
+ cur_new_embed,
301
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
302
+ ), dim=0))
303
+ if cur_len > 0:
304
+ new_labels_padded[i, :cur_len] = cur_new_labels
305
+ attention_mask[i, :cur_len] = True
306
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
307
+
308
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
309
+
310
+ if _labels is None:
311
+ new_labels = None
312
+ else:
313
+ new_labels = new_labels_padded
314
+
315
+ if _attention_mask is None:
316
+ attention_mask = None
317
+ else:
318
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
319
+
320
+ if _position_ids is None:
321
+ position_ids = None
322
+
323
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
324
+
325
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
326
+ if model_args.mm_use_im_patch_token:
327
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
328
+ self.resize_token_embeddings(len(tokenizer))
329
+
330
+ if model_args.mm_use_im_start_end:
331
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
332
+ self.resize_token_embeddings(len(tokenizer))
333
+
334
+ if num_new_tokens > 0:
335
+ input_embeddings = self.get_input_embeddings().weight.data
336
+ output_embeddings = self.get_output_embeddings().weight.data
337
+
338
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
339
+ dim=0, keepdim=True)
340
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
341
+ dim=0, keepdim=True)
342
+
343
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
344
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
345
+
346
+ if model_args.tune_mm_mlp_adapter:
347
+ for p in self.get_input_embeddings().parameters():
348
+ p.requires_grad = True
349
+ for p in self.get_output_embeddings().parameters():
350
+ p.requires_grad = False
351
+
352
+ if model_args.pretrain_mm_mlp_adapter:
353
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
354
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
355
+ assert num_new_tokens == 2
356
+ if input_embeddings.shape == embed_tokens_weight.shape:
357
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
358
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
359
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
360
+ else:
361
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
362
+ elif model_args.mm_use_im_patch_token:
363
+ if model_args.tune_mm_mlp_adapter:
364
+ for p in self.get_input_embeddings().parameters():
365
+ p.requires_grad = False
366
+ for p in self.get_output_embeddings().parameters():
367
+ p.requires_grad = False
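A rough, self-contained illustration of the placeholder-splicing idea in prepare_inputs_labels_for_multimodal above; all numbers are made up for illustration.

import torch

IMAGE_TOKEN_INDEX, IGNORE_INDEX = -200, -100              # mirrors llava.constants
input_ids = torch.tensor([1, 32, 45, IMAGE_TOKEN_INDEX, 77, 2])
text_chunks = [input_ids[:3], input_ids[4:]]              # text on either side of the image token
image_feats = torch.zeros(576, 4096)                      # e.g. 24x24 patches projected to hidden_size
# the spliced sequence is text embeds + image features + text embeds,
# with labels over the image span filled with IGNORE_INDEX
final_len = sum(c.numel() for c in text_chunks) + image_feats.shape[0]
print(final_len)                                          # 581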
llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31
+ bparam = base.state_dict()[name]
32
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
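A minimal sketch of the delta arithmetic performed above, assuming matching shapes; a consumer recovers the target weights by adding the base back onto the stored delta.

import torch

base_w = torch.randn(4, 4)
target_w = base_w + 0.1                    # stands in for a fine-tuned weight
delta_w = target_w - base_w                # what make_delta stores for matching shapes
assert torch.allclose(base_w + delta_w, target_w)
# mismatched shapes (embed_tokens / lm_head grown by new tokens) only subtract the
# overlapping slice, and projector weights absent from the base are stored as-is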
llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,20 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
3
+ from .siglip_encoder import SigLipVisionTower
4
+
5
+
6
+ def build_vision_tower(vision_tower_cfg, **kwargs):
7
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
8
+ is_absolute_path_exists = os.path.exists(vision_tower)
9
+ use_s2 = getattr(vision_tower_cfg, 's2', False)
10
+ if 'siglip' not in vision_tower.lower():
11
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
12
+ if use_s2:
13
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
14
+ else:
15
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16
+ else:
17
+ if is_absolute_path_exists or vision_tower.startswith("google") or vision_tower.startswith('bczhou'):
18
+ return SigLipVisionTower(vision_tower, vision_tower_cfg, **kwargs)
19
+
20
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
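A hedged sketch of the dispatch above; the config object only needs the attributes the towers read, so types.SimpleNamespace stands in for the real model config, and the hub ids are the commonly used CLIP / SigLIP checkpoints.

from types import SimpleNamespace

clip_cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",   # starts with "openai" -> CLIP branch
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
    s2=False,
)
clip_tower = build_vision_tower(clip_cfg, delay_load=True)       # CLIPVisionTower (fetches config only)

siglip_cfg = SimpleNamespace(mm_vision_tower="google/siglip-so400m-patch14-384")
siglip_tower = build_vision_tower(siglip_cfg, delay_load=True)   # "siglip" in the name -> SigLipVisionTower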
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
20
+ self.load_model()
21
+ else:
22
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
+
24
+ def load_model(self, device_map=None):
25
+ if self.is_loaded:
26
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27
+ return
28
+ print(f'Load vision tower from {self.vision_tower_name}')
29
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
31
+ self.vision_tower.requires_grad_(False)
32
+
33
+ self.is_loaded = True
34
+
35
+ def feature_select(self, image_forward_outs):
36
+ image_features = image_forward_outs.hidden_states[self.select_layer]
37
+ if self.select_feature == 'patch':
38
+ image_features = image_features[:, 1:]
39
+ elif self.select_feature == 'cls_patch':
40
+ image_features = image_features
41
+ else:
42
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
43
+ return image_features
44
+
45
+ # @torch.no_grad()
46
+ def forward(self, images):
47
+ if type(images) is list:
48
+ image_features = []
49
+ for image in images:
50
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
52
+ image_features.append(image_feature)
53
+ else:
54
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
56
+
57
+ return image_features
58
+
59
+ @property
60
+ def dummy_feature(self):
61
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62
+
63
+ @property
64
+ def dtype(self):
65
+ return self.vision_tower.dtype
66
+
67
+ @property
68
+ def device(self):
69
+ return self.vision_tower.device
70
+
71
+ @property
72
+ def config(self):
73
+ if self.is_loaded:
74
+ return self.vision_tower.config
75
+ else:
76
+ return self.cfg_only
77
+
78
+ @property
79
+ def hidden_size(self):
80
+ return self.config.hidden_size
81
+
82
+ @property
83
+ def num_patches_per_side(self):
84
+ return self.config.image_size // self.config.patch_size
85
+
86
+ @property
87
+ def num_patches(self):
88
+ return (self.config.image_size // self.config.patch_size) ** 2
89
+
90
+
91
+
92
+ class CLIPVisionTowerS2(CLIPVisionTower):
93
+ def __init__(self, vision_tower, args, delay_load=False):
94
+ super().__init__(vision_tower, args, delay_load)
95
+
96
+ self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
97
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
98
+ self.s2_scales.sort()
99
+ self.s2_split_size = self.s2_scales[0]
100
+ self.s2_image_size = self.s2_scales[-1]
101
+
102
+ try:
103
+ from s2wrapper import forward as multiscale_forward
104
+ except ImportError:
105
+ raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
106
+ self.multiscale_forward = multiscale_forward
107
+
108
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
109
+ if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
110
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
111
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
112
+
113
+ def load_model(self, device_map=None):
114
+ if self.is_loaded:
115
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
116
+ return
117
+
118
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
119
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
120
+ self.vision_tower.requires_grad_(False)
121
+
122
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
123
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
124
+
125
+ self.is_loaded = True
126
+
127
+ # @torch.no_grad()
128
+ def forward_feature(self, images):
129
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
130
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
131
+ return image_features
132
+
133
+ # @torch.no_grad()
134
+ def forward(self, images):
135
+ if type(images) is list:
136
+ image_features = []
137
+ for image in images:
138
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
139
+ image_features.append(image_feature)
140
+ else:
141
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
142
+
143
+ return image_features
144
+
145
+ @property
146
+ def hidden_size(self):
147
+ return self.config.hidden_size * len(self.s2_scales)
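Back-of-the-envelope shapes for the S2 variant above, assuming the default 336/672/1008 scales on CLIP ViT-L/14: channels grow with the number of scales while the token count stays that of the smallest scale.

scales = [336, 672, 1008]
clip_hidden = 1024                    # ViT-L/14 hidden size (assumed)
print(clip_hidden * len(scales))      # 3072 channels per patch token after concatenation
print((336 // 14) ** 2)               # 576 patch tokens per image, unchanged by S2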
llava/model/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,654 @@
1
+ '''
2
+ # Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
3
+ '''
4
+
5
+ from typing import Optional, Tuple, Union, Dict
6
+ from dataclasses import dataclass
7
+ from functools import partial, reduce
8
+ from PIL import Image
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+ import os
13
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
14
+ from transformers.image_transforms import (
15
+ convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
16
+ from transformers.image_utils import (
17
+ ChannelDimension, PILImageResampling, to_numpy_array, )
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
20
+ from transformers.modeling_utils import PreTrainedModel
21
+ from transformers import PretrainedConfig
22
+ from transformers.utils import ModelOutput
23
+ from transformers.image_processing_utils import BaseImageProcessor
24
+
25
+
26
+ class SigLipImageProcessor(BaseImageProcessor):
27
+ def __init__(self,
28
+ image_mean=(0.5, 0.5, 0.5),
29
+ image_std=(0.5, 0.5, 0.5),
30
+ size=(384, 384),
31
+ crop_size: Dict[str, int] = None,
32
+ resample=PILImageResampling.BICUBIC,
33
+ rescale_factor=1 / 255,
34
+ data_format=ChannelDimension.FIRST):
35
+ crop_size = crop_size if crop_size is not None else {
36
+ "height": 384, "width": 384}
37
+ crop_size = get_size_dict(
38
+ crop_size, default_to_square=True, param_name="crop_size")
39
+
40
+ self.image_mean = image_mean
41
+ self.image_std = image_std
42
+ self.size = size
43
+ self.resample = resample
44
+ self.rescale_factor = rescale_factor
45
+ self.data_format = data_format
46
+ self.crop_size = crop_size
47
+
48
+ def preprocess(self, images, return_tensors):
49
+ if isinstance(images, Image.Image):
50
+ images = [images]
51
+ else:
52
+ assert isinstance(images, list)
53
+
54
+ transforms = [
55
+ convert_to_rgb,
56
+ to_numpy_array,
57
+ partial(resize, size=self.size, resample=self.resample,
58
+ data_format=self.data_format),
59
+ partial(rescale, scale=self.rescale_factor,
60
+ data_format=self.data_format),
61
+ partial(normalize, mean=self.image_mean,
62
+ std=self.image_std, data_format=self.data_format),
63
+ partial(to_channel_dimension_format, channel_dim=self.data_format,
64
+ input_channel_dim=self.data_format),
65
+ ]
66
+
67
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
68
+ data = {"pixel_values": images}
69
+
70
+ return BatchFeature(data=data, tensor_type=return_tensors)
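A quick usage sketch of the processor above, assuming a PIL image: it resizes to 384x384, rescales to [0, 1] and normalizes with mean/std 0.5, returning channel-first tensors.

from PIL import Image

proc = SigLipImageProcessor()
img = Image.new("RGB", (640, 480))                        # placeholder image
pixel_values = proc.preprocess(img, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)                                 # torch.Size([1, 3, 384, 384])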
71
+
72
+
73
+ class SigLipVisionConfig(PretrainedConfig):
74
+ model_type = "siglip_vision_model"
75
+
76
+ def __init__(
77
+ self,
78
+ hidden_size=1152,
79
+ image_mean=(0.5, 0.5, 0.5),
80
+ intermediate_size=4304,
81
+ num_hidden_layers=27,
82
+ num_attention_heads=16,
83
+ num_channels=3,
84
+ image_size=384,
85
+ patch_size=14,
86
+ hidden_act="gelu_pytorch_tanh",
87
+ layer_norm_eps=1e-6,
88
+ attention_dropout=0.0,
89
+ **kwargs,
90
+ ):
91
+ super().__init__(**kwargs)
92
+
93
+ self.hidden_size = hidden_size
94
+ self.intermediate_size = intermediate_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.num_channels = num_channels
98
+ self.patch_size = patch_size
99
+ self.image_size = image_size
100
+ self.attention_dropout = attention_dropout
101
+ self.layer_norm_eps = layer_norm_eps
102
+ self.hidden_act = hidden_act
103
+ self.image_mean = image_mean
104
+
105
+ @classmethod
106
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
107
+ cls._set_token_in_kwargs(kwargs)
108
+
109
+ config_dict, kwargs = cls.get_config_dict(
110
+ pretrained_model_name_or_path, **kwargs)
111
+
112
+ # get the vision config dict if we are loading from SigLipConfig
113
+ if config_dict.get("model_type") == "siglip":
114
+ config_dict = config_dict["vision_config"]
115
+
116
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
117
+ logger.warning(
118
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
119
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
120
+ )
121
+
122
+ return cls.from_dict(config_dict, **kwargs)
123
+
124
+
125
+ @dataclass
126
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
127
+ class SigLipVisionModelOutput(ModelOutput):
128
+ """
129
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
130
+
131
+ Args:
132
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
133
+ The image embeddings obtained by applying the projection layer to the pooler_output.
134
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
135
+ Sequence of hidden-states at the output of the last layer of the model.
136
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
137
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
138
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
139
+
140
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
141
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
142
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
143
+ sequence_length)`.
144
+
145
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
146
+ heads.
147
+ """
148
+
149
+ image_embeds: Optional[torch.FloatTensor] = None
150
+ last_hidden_state: torch.FloatTensor = None
151
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
152
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
153
+
154
+
155
+ class SigLipVisionEmbeddings(nn.Module):
156
+ def __init__(self, config: SigLipVisionConfig):
157
+ super().__init__()
158
+ self.config = config
159
+ self.embed_dim = config.hidden_size
160
+ self.image_size = config.image_size
161
+ self.patch_size = config.patch_size
162
+
163
+ self.patch_embedding = nn.Conv2d(
164
+ in_channels=config.num_channels,
165
+ out_channels=self.embed_dim,
166
+ kernel_size=self.patch_size,
167
+ stride=self.patch_size,
168
+ padding="valid",
169
+ )
170
+
171
+ self.num_patches = (self.image_size // self.patch_size) ** 2
172
+ self.num_positions = self.num_patches
173
+ self.position_embedding = nn.Embedding(
174
+ self.num_positions, self.embed_dim)
175
+ self.register_buffer("position_ids", torch.arange(
176
+ self.num_positions).expand((1, -1)), persistent=False)
177
+
178
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
179
+ patch_embeds = self.patch_embedding(
180
+ pixel_values) # shape = [*, width, grid, grid]
181
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
182
+
183
+ embeddings = embeddings + self.position_embedding(self.position_ids)
184
+ return embeddings
185
+
186
+
187
+ class SigLipAttention(nn.Module):
188
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
189
+
190
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.config = config
194
+ self.embed_dim = config.hidden_size
195
+ self.num_heads = config.num_attention_heads
196
+ self.head_dim = self.embed_dim // self.num_heads
197
+ if self.head_dim * self.num_heads != self.embed_dim:
198
+ raise ValueError(
199
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
200
+ f" {self.num_heads})."
201
+ )
202
+ self.scale = self.head_dim ** -0.5
203
+ self.dropout = config.attention_dropout
204
+
205
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
206
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
207
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
208
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
209
+
210
+ def forward(
211
+ self,
212
+ hidden_states: torch.Tensor,
213
+ attention_mask: Optional[torch.Tensor] = None,
214
+ output_attentions: Optional[bool] = False,
215
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
216
+ """Input shape: Batch x Time x Channel"""
217
+
218
+ batch_size, q_len, _ = hidden_states.size()
219
+
220
+ query_states = self.q_proj(hidden_states)
221
+ key_states = self.k_proj(hidden_states)
222
+ value_states = self.v_proj(hidden_states)
223
+
224
+ query_states = query_states.view(
225
+ batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
226
+ key_states = key_states.view(
227
+ batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
228
+ value_states = value_states.view(
229
+ batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
230
+
231
+ k_v_seq_len = key_states.shape[-2]
232
+ attn_weights = torch.matmul(
233
+ query_states, key_states.transpose(2, 3)) * self.scale
234
+
235
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
236
+ raise ValueError(
237
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
238
+ f" {attn_weights.size()}"
239
+ )
240
+
241
+ if attention_mask is not None:
242
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
243
+ raise ValueError(
244
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
245
+ )
246
+ attn_weights = attn_weights + attention_mask
247
+
248
+ # upcast attention to fp32
249
+ attn_weights = nn.functional.softmax(
250
+ attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
251
+ attn_weights = nn.functional.dropout(
252
+ attn_weights, p=self.dropout, training=self.training)
253
+ attn_output = torch.matmul(attn_weights, value_states)
254
+
255
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
256
+ raise ValueError(
257
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
258
+ f" {attn_output.size()}"
259
+ )
260
+
261
+ attn_output = attn_output.transpose(1, 2).contiguous()
262
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
263
+
264
+ attn_output = self.out_proj(attn_output)
265
+
266
+ return attn_output, attn_weights
267
+
268
+
269
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
270
+ class SigLipMLP(nn.Module):
271
+ def __init__(self, config):
272
+ super().__init__()
273
+ self.config = config
274
+ self.activation_fn = ACT2FN[config.hidden_act]
275
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
276
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
277
+
278
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
279
+ hidden_states = self.fc1(hidden_states)
280
+ hidden_states = self.activation_fn(hidden_states)
281
+ hidden_states = self.fc2(hidden_states)
282
+ return hidden_states
283
+
284
+
285
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
286
+ class SigLipEncoderLayer(nn.Module):
287
+ def __init__(self, config: SigLipVisionConfig):
288
+ super().__init__()
289
+ self.embed_dim = config.hidden_size
290
+ self.self_attn = SigLipAttention(config)
291
+ self.layer_norm1 = nn.LayerNorm(
292
+ self.embed_dim, eps=config.layer_norm_eps)
293
+ self.mlp = SigLipMLP(config)
294
+ self.layer_norm2 = nn.LayerNorm(
295
+ self.embed_dim, eps=config.layer_norm_eps)
296
+
297
+ # Ignore copy
298
+ def forward(
299
+ self,
300
+ hidden_states: torch.Tensor,
301
+ attention_mask: torch.Tensor,
302
+ output_attentions: Optional[bool] = False,
303
+ ) -> Tuple[torch.FloatTensor]:
304
+ """
305
+ Args:
306
+ hidden_states (`torch.FloatTensor`):
307
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
308
+ attention_mask (`torch.FloatTensor`):
309
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
310
+ output_attentions (`bool`, *optional*, defaults to `False`):
311
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
312
+ returned tensors for more detail.
313
+ """
314
+ residual = hidden_states
315
+
316
+ hidden_states = self.layer_norm1(hidden_states)
317
+ hidden_states, attn_weights = self.self_attn(
318
+ hidden_states=hidden_states,
319
+ attention_mask=attention_mask,
320
+ output_attentions=output_attentions,
321
+ )
322
+ hidden_states = residual + hidden_states
323
+
324
+ residual = hidden_states
325
+ hidden_states = self.layer_norm2(hidden_states)
326
+ hidden_states = self.mlp(hidden_states)
327
+ hidden_states = residual + hidden_states
328
+
329
+ outputs = (hidden_states,)
330
+
331
+ if output_attentions:
332
+ outputs += (attn_weights,)
333
+
334
+ return outputs
335
+
336
+
337
+ class SigLipPreTrainedModel(PreTrainedModel):
338
+ """
339
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
340
+ models.
341
+ """
342
+
343
+ config_class = SigLipVisionConfig
344
+ base_model_prefix = "siglip"
345
+ supports_gradient_checkpointing = True
346
+
347
+ def _init_weights(self, module):
348
+ """Initialize the weights"""
349
+ pass
350
+
351
+
352
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
353
+ class SigLipEncoder(nn.Module):
354
+ """
355
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
356
+ [`SigLipEncoderLayer`].
357
+
358
+ Args:
359
+ config: SigLipVisionConfig
360
+ """
361
+
362
+ def __init__(self, config: SigLipVisionConfig):
363
+ super().__init__()
364
+ self.config = config
365
+ self.layers = nn.ModuleList(
366
+ [SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
367
+ self.gradient_checkpointing = False
368
+
369
+ # Ignore copy
370
+ def forward(
371
+ self,
372
+ inputs_embeds,
373
+ attention_mask: Optional[torch.Tensor] = None,
374
+ output_attentions: Optional[bool] = None,
375
+ output_hidden_states: Optional[bool] = None,
376
+ return_dict: Optional[bool] = None,
377
+ ) -> Union[Tuple, BaseModelOutput]:
378
+ r"""
379
+ Args:
380
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
381
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
382
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
383
+ than the model's internal embedding lookup matrix.
384
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
385
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
386
+
387
+ - 1 for tokens that are **not masked**,
388
+ - 0 for tokens that are **masked**.
389
+
390
+ [What are attention masks?](../glossary#attention-mask)
391
+ output_attentions (`bool`, *optional*):
392
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
393
+ returned tensors for more detail.
394
+ output_hidden_states (`bool`, *optional*):
395
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
396
+ for more detail.
397
+ return_dict (`bool`, *optional*):
398
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
399
+ """
400
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
401
+ output_hidden_states = (
402
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
403
+ )
404
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
405
+
406
+ encoder_states = () if output_hidden_states else None
407
+ all_attentions = () if output_attentions else None
408
+
409
+ hidden_states = inputs_embeds
410
+ for encoder_layer in self.layers:
411
+ if output_hidden_states:
412
+ encoder_states = encoder_states + (hidden_states,)
413
+ if self.gradient_checkpointing and self.training:
414
+ layer_outputs = self._gradient_checkpointing_func(
415
+ encoder_layer.__call__,
416
+ hidden_states,
417
+ attention_mask,
418
+ output_attentions,
419
+ )
420
+ else:
421
+ layer_outputs = encoder_layer(
422
+ hidden_states,
423
+ attention_mask,
424
+ output_attentions=output_attentions,
425
+ )
426
+
427
+ hidden_states = layer_outputs[0]
428
+
429
+ if output_attentions:
430
+ all_attentions = all_attentions + (layer_outputs[1],)
431
+
432
+ if output_hidden_states:
433
+ encoder_states = encoder_states + (hidden_states,)
434
+
435
+ if not return_dict:
436
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
437
+ return BaseModelOutput(
438
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
439
+ )
440
+
441
+
442
+ class SigLipVisionTransformer(nn.Module):
443
+ def __init__(self, config: SigLipVisionConfig):
444
+ super().__init__()
445
+ self.config = config
446
+ embed_dim = config.hidden_size
447
+
448
+ self.embeddings = SigLipVisionEmbeddings(config)
449
+ self.encoder = SigLipEncoder(config)
450
+ self.post_layernorm = nn.LayerNorm(
451
+ embed_dim, eps=config.layer_norm_eps)
452
+ self.head = SigLipMultiheadAttentionPoolingHead(config)
453
+
454
+ def forward(
455
+ self,
456
+ pixel_values,
457
+ output_attentions: Optional[bool] = None,
458
+ output_hidden_states: Optional[bool] = None,
459
+ return_dict: Optional[bool] = None,
460
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
461
+ r"""
462
+ Returns:
463
+
464
+ """
465
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
466
+ output_hidden_states = (
467
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
468
+ )
469
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
470
+
471
+ hidden_states = self.embeddings(pixel_values)
472
+
473
+ encoder_outputs = self.encoder(
474
+ inputs_embeds=hidden_states,
475
+ output_attentions=output_attentions,
476
+ output_hidden_states=output_hidden_states,
477
+ return_dict=return_dict,
478
+ )
479
+
480
+ last_hidden_state = encoder_outputs[0]
481
+ last_hidden_state = self.post_layernorm(last_hidden_state)
482
+
483
+ pooled_output = self.head(last_hidden_state)
484
+
485
+ if not return_dict:
486
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
487
+
488
+ return BaseModelOutputWithPooling(
489
+ last_hidden_state=last_hidden_state,
490
+ pooler_output=pooled_output,
491
+ hidden_states=encoder_outputs.hidden_states,
492
+ attentions=encoder_outputs.attentions,
493
+ )
494
+
495
+
496
+ class SigLipMultiheadAttentionPoolingHead(nn.Module):
497
+ """Multihead Attention Pooling."""
498
+
499
+ def __init__(self, config: SigLipVisionConfig):
500
+ super().__init__()
501
+
502
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
503
+ self.attention = torch.nn.MultiheadAttention(
504
+ config.hidden_size, config.num_attention_heads, batch_first=True)
505
+ self.layernorm = nn.LayerNorm(
506
+ config.hidden_size, eps=config.layer_norm_eps)
507
+ self.mlp = SigLipMLP(config)
508
+
509
+ def forward(self, hidden_state):
510
+ batch_size = hidden_state.shape[0]
511
+ probe = self.probe.repeat(batch_size, 1, 1)
512
+
513
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
514
+
515
+ residual = hidden_state
516
+ hidden_state = self.layernorm(hidden_state)
517
+ hidden_state = residual + self.mlp(hidden_state)
518
+
519
+ return hidden_state[:, 0]
520
+
521
+
522
+ class SigLipVisionModel(SigLipPreTrainedModel):
523
+ config_class = SigLipVisionConfig
524
+ main_input_name = "pixel_values"
525
+ _no_split_modules = ["SigLipEncoderLayer"]
526
+
527
+ def __init__(self, config: SigLipVisionConfig):
528
+ super().__init__(config)
529
+
530
+ self.vision_model = SigLipVisionTransformer(config)
531
+ del self.vision_model.encoder.layers[-1:]
532
+ self.vision_model.head = nn.Identity()
533
+
534
+ # Initialize weights and apply final processing
535
+ self.post_init()
536
+
537
+ def get_input_embeddings(self) -> nn.Module:
538
+ return self.vision_model.embeddings.patch_embedding
539
+
540
+ def forward(
541
+ self,
542
+ pixel_values,
543
+ output_attentions: Optional[bool] = None,
544
+ output_hidden_states: Optional[bool] = None,
545
+ return_dict: Optional[bool] = None,
546
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
547
+ r"""
548
+ Returns:
549
+
550
+ Examples:
551
+
552
+ ```python
553
+ >>> from PIL import Image
554
+ >>> import requests
555
+ >>> from transformers import AutoProcessor, SigLipVisionModel
556
+
557
+ >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
558
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
559
+
560
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
561
+ >>> image = Image.open(requests.get(url, stream=True).raw)
562
+
563
+ >>> inputs = processor(images=image, return_tensors="pt")
564
+
565
+ >>> outputs = model(**inputs)
566
+ >>> last_hidden_state = outputs.last_hidden_state
567
+ >>> pooled_output = outputs.pooler_output # pooled features
568
+ ```"""
569
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
570
+
571
+ return self.vision_model(
572
+ pixel_values=pixel_values,
573
+ output_attentions=output_attentions,
574
+ output_hidden_states=output_hidden_states,
575
+ return_dict=return_dict,
576
+ )
577
+
578
+
579
+ class SigLipVisionTower(nn.Module):
580
+ def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
581
+ super().__init__()
582
+
583
+ self.is_loaded = False
584
+
585
+ self.config = SigLipVisionConfig()
586
+
587
+ self.vision_tower_name = vision_tower
588
+
589
+ self.image_processor = SigLipImageProcessor()
590
+
591
+ if not delay_load:
592
+ self.load_model()
593
+ else:
594
+ self.cfg_only = self.config
595
+
596
+ def load_model(self, device_map=None):
597
+ if self.is_loaded:
598
+ print('{} is already loaded, `load_model` called again, skipping.'.format(
599
+ self.vision_tower_name))
600
+ return
601
+
602
+ self.vision_tower = SigLipVisionModel.from_pretrained(
603
+ self.vision_tower_name, device_map=device_map)
604
+
605
+ self.vision_tower.requires_grad_(False)
606
+ self.vision_tower.eval()
607
+
608
+ self.is_loaded = True
609
+
610
+ # @torch.no_grad()
611
+ def forward(self, images):
612
+ if type(images) is list:
613
+ image_features = []
614
+ for image in images:
615
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
616
+ output_hidden_states=True)
617
+ image_feature = image_forward_out.hidden_states[-1].to(
618
+ image.dtype)
619
+ assert image_feature.shape[-2] == 729
620
+ image_features.append(image_feature)
621
+ else:
622
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
623
+ output_hidden_states=True)
624
+ image_features = image_forward_outs.hidden_states[-1].to(
625
+ images.dtype)
626
+ assert image_features.shape[-2] == 729
627
+
628
+ return image_features
629
+
630
+ @property
631
+ def dummy_feature(self):
632
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
633
+
634
+ @property
635
+ def dtype(self):
636
+ for p in self.vision_tower.parameters():
637
+ return p.dtype
638
+
639
+ @property
640
+ def device(self):
641
+ for p in self.vision_tower.parameters():
642
+ return p.device
643
+
644
+ @property
645
+ def hidden_size(self):
646
+ return self.config.hidden_size
647
+
648
+ @property
649
+ def num_patches_per_side(self):
650
+ return self.config.image_size // self.config.patch_size
651
+
652
+ @property
653
+ def num_patches(self):
654
+ return (self.config.image_size // self.config.patch_size) ** 2
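A hedged usage sketch (not part of this commit) of the tower defined above, assuming a SigLip checkpoint whose 27×27 patch grid yields the 729 tokens asserted in `forward`; the checkpoint name and input shapes are illustrative assumptions.

```python
# Illustrative sketch only; assumes the SigLip classes above are importable
# and that the checkpoint produces a 27 x 27 = 729 patch grid.
import torch

tower = SigLipVisionTower(
    "google/siglip-so400m-patch14-384",   # assumed checkpoint name
    vision_tower_cfg=None,
    delay_load=False,
)
images = torch.randn(2, 3, 384, 384, dtype=tower.dtype, device=tower.device)
features = tower(images)                   # shape: (2, 729, hidden_size)
print(features.shape, tower.num_patches)   # 729 patches per image
```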
llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+
6
+ class IdentityMap(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_projector_type": 'identity'}
16
+
17
+
18
+ class SimpleResBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.pre_norm = nn.LayerNorm(channels)
22
+
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(channels, channels),
25
+ nn.GELU(),
26
+ nn.Linear(channels, channels)
27
+ )
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
35
+
36
+ if projector_type == 'linear':
37
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
38
+
39
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40
+ if mlp_gelu_match:
41
+ mlp_depth = int(mlp_gelu_match.group(1))
42
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43
+ for _ in range(1, mlp_depth):
44
+ modules.append(nn.GELU())
45
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46
+ return nn.Sequential(*modules)
47
+
48
+ if projector_type == 'identity':
49
+ return IdentityMap()
50
+
51
+ raise ValueError(f'Unknown projector type: {projector_type}')
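As a quick illustration of the projector factory above (a sketch with assumed sizes, not values taken from this repository), `mm_projector_type='mlp2x_gelu'` expands to Linear → GELU → Linear from `mm_hidden_size` to `hidden_size`:

```python
# Illustrative only: exercising build_vision_projector with assumed sizes.
from types import SimpleNamespace

cfg = SimpleNamespace(
    mm_projector_type="mlp2x_gelu",
    mm_hidden_size=1152,   # assumed vision feature size
    hidden_size=4096,      # assumed LLM hidden size
)
projector = build_vision_projector(cfg)
print(projector)  # Sequential(Linear(1152->4096), GELU(), Linear(4096->4096))
```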
llava/model/utils.py ADDED
@@ -0,0 +1,20 @@
1
+ from transformers import AutoConfig
2
+
3
+
4
+ def auto_upgrade(config):
5
+ cfg = AutoConfig.from_pretrained(config)
6
+ if 'llava' in config and 'llava' not in cfg.model_type:
7
+ assert cfg.model_type == 'llama'
8
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
+ if confirm.lower() in ["y", "yes"]:
12
+ print("Upgrading checkpoint...")
13
+ assert len(cfg.architectures) == 1
14
+ setattr(cfg.__class__, "model_type", "llava")
15
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
+ cfg.save_pretrained(config)
17
+ print("Checkpoint upgraded.")
18
+ else:
19
+ print("Checkpoint upgrade aborted.")
20
+ exit(1)
llava/serve/gradio_utils.py ADDED
@@ -0,0 +1,186 @@
1
+ import numpy as np
2
+ import torch
3
+ from decord import VideoReader, cpu
4
+ from PIL import Image
5
+
6
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
7
+ from llava.conversation import conv_templates
8
+ from llava.mm_utils import (KeywordsStoppingCriteria, get_model_name_from_path,
9
+ process_images, tokenizer_image_token)
10
+ from llava.model.builder import load_pretrained_model
11
+ from llava.utils import disable_torch_init
12
+
13
+ title_markdown = ("""
14
+ <div style="display: flex; justify-content: flex-start; align-items: center; text-align: center;">
15
+ <div style="margin-right: 20px; display: flex; align-items: center;">
16
+ <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video" style="text-decoration: none; display: flex; align-items: center;">
17
+ <img src="https://raw.githubusercontent.com/ShareGPT4V/ShareGPT4V-Resources/master/images/share4video_tight.png" alt="ShareGPT4Video🚀" style="max-width: 120px; height: auto;">
18
+ </a>
19
+ </div>
20
+ <div>
21
+ <h1>ShareGPT4Video: Improving Video Understanding and Generation with Better Captions</h1>
22
+ <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
23
+ <h5 style="margin: 0;"> <a href="https://sharegpt4video.github.io/">[Project Page]</a> <a href="https://github.com/ShareGPT4Omni/ShareGPT4Video">[Code]</a> <a href="https://arxiv.org/abs/2406.04325v1">[Paper]</a>
24
+ </div>
25
+ </div>
26
+ """)
27
+
28
+ block_css = """
29
+ #buttons button {
30
+ min-width: min(120px,100%);
31
+ }
32
+ """
33
+
34
+
35
+ learn_more_markdown = ("""
36
+ ### License
37
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
38
+ """)
39
+
40
+
41
+ def create_frame_grid(img_array, interval_width=50):
42
+ n, h, w, c = img_array.shape
43
+ grid_size = int(np.ceil(np.sqrt(n)))
44
+
45
+ horizontal_band = np.ones((h, interval_width, c),
46
+ dtype=img_array.dtype) * 255
47
+ vertical_band = np.ones((interval_width, w + (grid_size - 1)
48
+ * (w + interval_width), c), dtype=img_array.dtype) * 255
49
+
50
+ rows = []
51
+ for i in range(grid_size):
52
+ row_frames = []
53
+ for j in range(grid_size):
54
+ idx = i * grid_size + j
55
+ if idx < n:
56
+ frame = img_array[idx]
57
+ else:
58
+ frame = np.ones_like(img_array[0]) * 255
59
+ if j > 0:
60
+ row_frames.append(horizontal_band)
61
+ row_frames.append(frame)
62
+ combined_row = np.concatenate(row_frames, axis=1)
63
+ if i > 0:
64
+ rows.append(vertical_band)
65
+ rows.append(combined_row)
66
+
67
+ final_grid = np.concatenate(rows, axis=0)
68
+ return final_grid
69
+
70
+
71
+ def resize_image_grid(image, max_length=1920):
72
+ width, height = image.size
73
+ if max(width, height) > max_length:
74
+ if width > height:
75
+ scale = max_length / width
76
+ else:
77
+ scale = max_length / height
78
+
79
+ new_width = int(width * scale)
80
+ new_height = int(height * scale)
81
+
82
+ img_resized = image.resize((new_width, new_height), Image.BILINEAR)
83
+ else:
84
+ img_resized = image
85
+ return img_resized
86
+
87
+
88
+ def get_index(num_frames, num_segments):
89
+ seg_size = float(num_frames - 1) / num_segments
90
+ start = int(seg_size / 2)
91
+ offsets = np.array([
92
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
93
+ ])
94
+ return offsets
95
+
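A small worked example of the frame sampling above: for a 100-frame clip and 8 segments, `seg_size = 99 / 8 = 12.375` and `start = 6`, so each sampled index sits near the middle of its segment.

```python
# Reproducing get_index by hand for num_frames=100, num_segments=8.
import numpy as np

seg_size = float(100 - 1) / 8            # 12.375
start = int(seg_size / 2)                # 6
offsets = np.array([start + int(np.round(seg_size * i)) for i in range(8)])
print(offsets)                           # [ 6 18 31 43 56 68 80 93]
```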
96
+
97
+ def load_video(video_path, num_segments=8, return_msg=False, num_frames=4):
98
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
99
+ num_frames = len(vr)
100
+ frame_indices = get_index(num_frames, num_segments)
101
+ img_array = vr.get_batch(frame_indices).asnumpy()
102
+ img_grid = create_frame_grid(img_array, 50)
103
+ img_grid = Image.fromarray(img_grid).convert("RGB")
104
+ img_grid = resize_image_grid(img_grid)
105
+ if return_msg:
106
+ fps = float(vr.get_avg_fps())
107
+ sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
108
+ # " " should be added in the start and end
109
+ msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
110
+ return img_grid, msg
111
+ else:
112
+ return img_grid
113
+
114
+
115
+ def video_answer(prompt, model, processor, tokenizer, img_grid, do_sample=True,
116
+ max_new_tokens=200, num_beams=1, top_p=0.9,
117
+ temperature=1.0, print_res=False, **kwargs):
118
+ if not isinstance(img_grid, (list, tuple)):
119
+ img_grid = [img_grid]
120
+ image_size = img_grid[0].size
121
+ image_tensor = process_images(img_grid, processor, model.config)[0]
122
+ input_ids = tokenizer_image_token(
123
+ prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
124
+ input_ids = input_ids.unsqueeze(0).to(
125
+ device=model.device, non_blocking=True)
126
+ pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token is not None else tokenizer.eos_token_id
127
+
128
+ with torch.inference_mode():
129
+ output_ids = model.generate(
130
+ input_ids,
131
+ images=image_tensor.to(
132
+ dtype=torch.float16, device=model.device, non_blocking=True),
133
+ image_sizes=[image_size],
134
+ do_sample=do_sample,
135
+ temperature=temperature,
136
+ top_p=top_p,
137
+ num_beams=num_beams,
138
+ max_new_tokens=max_new_tokens,
139
+ pad_token_id=pad_token_id,
140
+ use_cache=True,
141
+ **kwargs)
142
+ outputs = tokenizer.batch_decode(
143
+ output_ids, skip_special_tokens=True)[0].strip()
144
+ if print_res: # debug usage
145
+ print('### PROMPTING LM WITH: ', prompt)
146
+ print('### LM OUTPUT TEXT: ', outputs)
147
+
148
+ return outputs
149
+
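A hedged usage sketch of `video_answer` (not part of this commit). The checkpoint id and video path are placeholder assumptions, and in practice the conversation template is applied to the prompt first, as the `Chat` class below does.

```python
# Illustrative only; model path and video file are assumed placeholders.
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

model_path = "Lin-Chen/sharegpt4video-8b"      # assumed checkpoint id
tokenizer, model, processor, _ = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), device="cuda")

img_grid = load_video("demo.mp4", num_segments=16)      # assumed local video file
prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe the video."  # normally wrapped by a conv template
answer = video_answer(prompt, model=model, processor=processor, tokenizer=tokenizer,
                      img_grid=img_grid, do_sample=False, max_new_tokens=128)
print(answer)
```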
150
+
151
+ class Chat:
152
+ def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', num_frames=16):
153
+ disable_torch_init()
154
+ model_name = get_model_name_from_path(model_path)
155
+ self.tokenizer, self.model, self.processor, context_len = load_pretrained_model(
156
+ model_path, model_base, model_name,
157
+ load_8bit, load_4bit,
158
+ device=device)
159
+ self.model.eval()
160
+ self.conv_mode = conv_mode
161
+ self.device = self.model.device
162
+ self.num_frames = num_frames
163
+ self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view, keyframes are separated with white bands. Answer concisely with overall content and context of the video, highlighting any significant events, characters, or objects that appear throughout the frames."
164
+
165
+ def get_prompt(self, qs, state):
166
+ state.append_message(state.roles[0], qs)
167
+ state.append_message(state.roles[1], None)
168
+ return state
169
+
170
+ @torch.inference_mode()
171
+ def generate(self, vid_path: list, prompt: str, first_run: bool, state):
172
+ if self.num_frames != 0:
173
+ vid, msg = load_video(
174
+ vid_path, num_segments=self.num_frames, return_msg=True)
175
+ else:
176
+ vid, msg = None, 'num_frames is 0, not inputting image'
177
+ img_grid = vid
178
+ if self.pre_query_prompt is not None:
179
+ prompt = DEFAULT_IMAGE_TOKEN + '\n' + self.pre_query_prompt + prompt
180
+ else:
181
+ prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
182
+ state = self.get_prompt(prompt, state)
183
+ prompt = state.get_prompt()
184
+ llm_response = video_answer(prompt, model=self.model, processor=self.processor, tokenizer=self.tokenizer,
185
+ do_sample=True, temperature=0.1, img_grid=img_grid, max_new_tokens=1024, print_res=True)
186
+ return llm_response, state
llava/train/llava_trainer.py ADDED
@@ -0,0 +1,301 @@
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from torch.utils.data import Sampler
6
+
7
+ from transformers import Trainer
8
+ from transformers.trainer import (
9
+ is_sagemaker_mp_enabled,
10
+ get_parameter_names,
11
+ has_length,
12
+ ALL_LAYERNORM_LAYERS,
13
+ logger,
14
+ )
15
+ from typing import List, Optional
16
+
17
+
18
+ def maybe_zero_3(param, ignore_status=False, name=None):
19
+ from deepspeed import zero
20
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
21
+ if hasattr(param, "ds_id"):
22
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
23
+ if not ignore_status:
24
+ print(name, 'no ignore status')
25
+ with zero.GatheredParameters([param]):
26
+ param = param.data.detach().cpu().clone()
27
+ else:
28
+ param = param.detach().cpu().clone()
29
+ return param
30
+
31
+
32
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
33
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
34
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
35
+ return to_return
36
+
37
+
38
+ def split_to_even_chunks(indices, lengths, num_chunks):
39
+ """
40
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
41
+ """
42
+
43
+ if len(indices) % num_chunks != 0:
44
+ return [indices[i::num_chunks] for i in range(num_chunks)]
45
+
46
+ num_indices_per_chunk = len(indices) // num_chunks
47
+
48
+ chunks = [[] for _ in range(num_chunks)]
49
+ chunks_lengths = [0 for _ in range(num_chunks)]
50
+ for index in indices:
51
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
52
+ chunks[shortest_chunk].append(index)
53
+ chunks_lengths[shortest_chunk] += lengths[index]
54
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
55
+ chunks_lengths[shortest_chunk] = float("inf")
56
+
57
+ return chunks
58
+
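A quick illustration of the greedy balancing above: each index goes to whichever chunk currently has the smallest accumulated length.

```python
# Illustration of split_to_even_chunks with 6 indices over 2 chunks.
indices = [0, 1, 2, 3, 4, 5]
lengths = [10, 9, 8, 1, 1, 1]
print(split_to_even_chunks(indices, lengths, num_chunks=2))
# -> [[0, 3, 4], [1, 2, 5]]  (chunk totals 12 and 18)
```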
59
+
60
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
61
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
62
+ assert all(l != 0 for l in lengths), "Should not have zero length."
63
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
64
+ # all samples are in the same modality
65
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
66
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
67
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
68
+
69
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
70
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
71
+ megabatch_size = world_size * batch_size
72
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
73
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
74
+
75
+ last_mm = mm_megabatches[-1]
76
+ last_lang = lang_megabatches[-1]
77
+ additional_batch = last_mm + last_lang
78
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
79
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
80
+ megabatches = [megabatches[i] for i in megabatch_indices]
81
+
82
+ if len(additional_batch) > 0:
83
+ megabatches.append(sorted(additional_batch))
84
+
85
+ return [i for megabatch in megabatches for i in megabatch]
86
+
87
+
88
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
89
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
90
+ indices = torch.randperm(len(lengths), generator=generator)
91
+ megabatch_size = world_size * batch_size
92
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
93
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
94
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
95
+
96
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
97
+
98
+
99
+ class LengthGroupedSampler(Sampler):
100
+ r"""
101
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
102
+ keeping a bit of randomness.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ batch_size: int,
108
+ world_size: int,
109
+ lengths: Optional[List[int]] = None,
110
+ generator=None,
111
+ group_by_modality: bool = False,
112
+ ):
113
+ if lengths is None:
114
+ raise ValueError("Lengths must be provided.")
115
+
116
+ self.batch_size = batch_size
117
+ self.world_size = world_size
118
+ self.lengths = lengths
119
+ self.generator = generator
120
+ self.group_by_modality = group_by_modality
121
+
122
+ def __len__(self):
123
+ return len(self.lengths)
124
+
125
+ def __iter__(self):
126
+ if self.group_by_modality:
127
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
128
+ else:
129
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
130
+ return iter(indices)
131
+
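A hedged sketch of using the sampler above on its own; the signed-length convention (positive = multimodal, negative = text-only) matches `modality_lengths` defined later in `llava/train/train.py`, and the toy lengths below are assumptions.

```python
# Illustrative only: group-by-modality sampling over toy signed lengths.
import torch

lengths = [100, 90, 80, 70, 60, 50, 40, 30,          # multimodal samples
           -95, -85, -75, -65, -55, -45, -35, -25]   # text-only samples
sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=2,
    lengths=lengths,
    generator=torch.Generator().manual_seed(0),
    group_by_modality=True,
)
order = list(sampler)
# The first two groups of 4 indices are modality-pure megabatches;
# the trailing remainder megabatch mixes both modalities.
print(order)
```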
132
+
133
+ class LLaVATrainer(Trainer):
134
+
135
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
136
+ if self.train_dataset is None or not has_length(self.train_dataset):
137
+ return None
138
+
139
+ if self.args.group_by_modality_length:
140
+ lengths = self.train_dataset.modality_lengths
141
+ return LengthGroupedSampler(
142
+ self.args.train_batch_size,
143
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
144
+ lengths=lengths,
145
+ group_by_modality=True,
146
+ )
147
+ else:
148
+ return super()._get_train_sampler()
149
+
150
+ def create_optimizer(self):
151
+ """
152
+ Setup the optimizer.
153
+
154
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
155
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
156
+ """
157
+ if is_sagemaker_mp_enabled():
158
+ return super().create_optimizer()
159
+
160
+ opt_model = self.model
161
+
162
+ if self.optimizer is None:
163
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
164
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
165
+ if self.args.mm_projector_lr is not None:
166
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
167
+ if self.args.mm_vision_tower_lr is not None:
168
+ vision_tower_parameters = [
169
+ name for name, _ in opt_model.named_parameters() if "vision_tower" in name]
170
+ optimizer_grouped_parameters = [
171
+ {
172
+ "params": [
173
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad)
174
+ ],
175
+ "weight_decay": self.args.weight_decay,
176
+ },
177
+ {
178
+ "params": [
179
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad)
180
+ ],
181
+ "weight_decay": self.args.weight_decay,
182
+ "lr": self.args.mm_vision_tower_lr,
183
+ },
184
+ {
185
+ "params": [
186
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad)
187
+ ],
188
+ "weight_decay": 0.0,
189
+ },
190
+ {
191
+ "params": [
192
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad)
193
+ ],
194
+ "weight_decay": 0.0,
195
+ "lr": self.args.mm_vision_tower_lr,
196
+ },
197
+ {
198
+ "params": [
199
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
200
+ ],
201
+ "weight_decay": self.args.weight_decay,
202
+ "lr": self.args.mm_projector_lr,
203
+ },
204
+ {
205
+ "params": [
206
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
207
+ ],
208
+ "weight_decay": 0.0,
209
+ "lr": self.args.mm_projector_lr,
210
+ },
211
+ ]
212
+ else:
213
+ optimizer_grouped_parameters = [
214
+ {
215
+ "params": [
216
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
217
+ ],
218
+ "weight_decay": self.args.weight_decay,
219
+ },
220
+ {
221
+ "params": [
222
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
223
+ ],
224
+ "weight_decay": 0.0,
225
+ },
226
+ {
227
+ "params": [
228
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
229
+ ],
230
+ "weight_decay": self.args.weight_decay,
231
+ "lr": self.args.mm_projector_lr,
232
+ },
233
+ {
234
+ "params": [
235
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
236
+ ],
237
+ "weight_decay": 0.0,
238
+ "lr": self.args.mm_projector_lr,
239
+ },
240
+ ]
241
+ else:
242
+ optimizer_grouped_parameters = [
243
+ {
244
+ "params": [
245
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
246
+ ],
247
+ "weight_decay": self.args.weight_decay,
248
+ },
249
+ {
250
+ "params": [
251
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
252
+ ],
253
+ "weight_decay": 0.0,
254
+ },
255
+ ]
256
+
257
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
258
+
259
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
260
+ if optimizer_cls.__name__ == "Adam8bit":
261
+ import bitsandbytes
262
+
263
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
264
+
265
+ skipped = 0
266
+ for module in opt_model.modules():
267
+ if isinstance(module, nn.Embedding):
268
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
269
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
270
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
271
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
272
+ logger.info(f"skipped: {skipped/2**20}M params")
273
+
274
+ return self.optimizer
275
+
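The parameter grouping above boils down to assigning separate learning rates (and a decay/no-decay split) per module family. A minimal standalone sketch with an assumed toy model and assumed learning-rate values, omitting the decay split for brevity:

```python
# Illustrative only: per-module learning rates, mirroring the grouping above.
import torch
import torch.nn as nn

model = nn.ModuleDict({
    "backbone": nn.Linear(8, 8),        # stands in for the LLM
    "vision_tower": nn.Linear(8, 8),    # stands in for the vision encoder
    "mm_projector": nn.Linear(8, 8),    # stands in for the projector
})
base_lr, vit_lr, proj_lr = 2e-5, 2e-6, 1e-5   # assumed values
optimizer = torch.optim.AdamW([
    {"params": model["backbone"].parameters(), "lr": base_lr},
    {"params": model["vision_tower"].parameters(), "lr": vit_lr},
    {"params": model["mm_projector"].parameters(), "lr": proj_lr},
])
```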
276
+ def _save_checkpoint(self, model, trial, metrics=None):
277
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
278
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
279
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
280
+
281
+ run_dir = self._get_output_dir(trial=trial)
282
+ output_dir = os.path.join(run_dir, checkpoint_folder)
283
+
284
+ # Only save Adapter
285
+ keys_to_match = ['mm_projector', 'vision_resampler']
286
+ if getattr(self.args, "use_im_start_end", False):
287
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
288
+
289
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
290
+
291
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
292
+ self.model.config.save_pretrained(output_dir)
293
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
294
+ else:
295
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
296
+
297
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
298
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
299
+ pass
300
+ else:
301
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
llava/train/train.py ADDED
@@ -0,0 +1,1323 @@
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import copy
17
+ import json
18
+ import logging
19
+ import os
20
+ import pathlib
21
+ from dataclasses import dataclass, field
22
+ from typing import Dict, List, Optional, Sequence
23
+
24
+ import tokenizers
25
+ import torch
26
+ import transformers
27
+ from packaging import version
28
+ from PIL import Image
29
+ from torch.utils.data import Dataset
30
+
31
+ from llava import conversation as conversation_lib
32
+ from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
33
+ DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
34
+ IMAGE_TOKEN_INDEX)
35
+ from llava.mm_utils import process_anyres_image, tokenizer_image_token
36
+ from llava.model import *
37
+ from llava.train.llava_trainer import LLaVATrainer
38
+
39
+ local_rank = None
40
+
41
+
42
+ def rank0_print(*args):
43
+ if local_rank == 0:
44
+ print(*args)
45
+
46
+
47
+ IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(
48
+ tokenizers.__version__) >= version.parse('0.14')
49
+
50
+
51
+ @dataclass
52
+ class ModelArguments:
53
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
54
+ version: Optional[str] = field(default="v0")
55
+ freeze_backbone: bool = field(default=False)
56
+ tune_mm_mlp_adapter: bool = field(default=False)
57
+ vision_tower: Optional[str] = field(default=None)
58
+ mm_vision_select_layer: Optional[int] = field(
59
+ default=-1) # default to the last layer
60
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
61
+ mm_projector_type: Optional[str] = field(default='linear')
62
+ mm_use_im_start_end: bool = field(default=False)
63
+ mm_use_im_patch_token: bool = field(default=True)
64
+ mm_patch_merge_type: Optional[str] = field(default='flat')
65
+ mm_vision_select_feature: Optional[str] = field(default="patch")
66
+
67
+
68
+ @dataclass
69
+ class DataArguments:
70
+ data_path: str = field(default=None,
71
+ metadata={"help": "Path to the training data."})
72
+ lazy_preprocess: bool = False
73
+ is_multimodal: bool = False
74
+ image_folder: Optional[str] = field(default=None)
75
+ image_aspect_ratio: str = 'square'
76
+
77
+
78
+ @dataclass
79
+ class TrainingArguments(transformers.TrainingArguments):
80
+ cache_dir: Optional[str] = field(default=None)
81
+ optim: str = field(default="adamw_torch")
82
+ remove_unused_columns: bool = field(default=False)
83
+ freeze_mm_mlp_adapter: bool = field(default=False)
84
+ unfreeze_mm_vision_tower: bool = field(default=False)
85
+ mpt_attn_impl: Optional[str] = field(default="triton")
86
+ model_max_length: int = field(
87
+ default=512,
88
+ metadata={
89
+ "help":
90
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
91
+ },
92
+ )
93
+ double_quant: bool = field(
94
+ default=True,
95
+ metadata={
96
+ "help": "Compress the quantization statistics through double quantization."}
97
+ )
98
+ quant_type: str = field(
99
+ default="nf4",
100
+ metadata={
101
+ "help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
102
+ )
103
+ bits: int = field(
104
+ default=16,
105
+ metadata={"help": "How many bits to use."}
106
+ )
107
+ lora_enable: bool = False
108
+ lora_r: int = 64
109
+ lora_alpha: int = 16
110
+ lora_dropout: float = 0.05
111
+ lora_weight_path: str = ""
112
+ lora_bias: str = "none"
113
+ lora_qv_proj_only: bool = False
114
+ mm_projector_lr: Optional[float] = None
115
+ mm_vision_tower_lr: Optional[float] = None
116
+ group_by_modality_length: bool = field(default=False)
117
+
118
+
119
+ def maybe_zero_3(param, ignore_status=False, name=None):
120
+ from deepspeed import zero
121
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
122
+ if hasattr(param, "ds_id"):
123
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
124
+ if not ignore_status:
125
+ logging.warning(
126
+ f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
127
+ with zero.GatheredParameters([param]):
128
+ param = param.data.detach().cpu().clone()
129
+ else:
130
+ param = param.detach().cpu().clone()
131
+ return param
132
+
133
+
134
+ # Borrowed from peft.utils.get_peft_model_state_dict
135
+ def get_peft_state_maybe_zero_3(named_params, bias):
136
+ if bias == "none":
137
+ to_return = {k: t for k, t in named_params if "lora_" in k}
138
+ elif bias == "all":
139
+ to_return = {k: t for k,
140
+ t in named_params if "lora_" in k or "bias" in k}
141
+ elif bias == "lora_only":
142
+ to_return = {}
143
+ maybe_lora_bias = {}
144
+ lora_bias_names = set()
145
+ for k, t in named_params:
146
+ if "lora_" in k:
147
+ to_return[k] = t
148
+ bias_name = k.split("lora_")[0] + "bias"
149
+ lora_bias_names.add(bias_name)
150
+ elif "bias" in k:
151
+ maybe_lora_bias[k] = t
152
+ for k, t in maybe_lora_bias.items():
153
+ if k in lora_bias_names:
154
+ to_return[k] = t
155
+ else:
156
+ raise NotImplementedError
157
+ to_return = {k: maybe_zero_3(v, ignore_status=True)
158
+ for k, v in to_return.items()}
159
+ return to_return
160
+
161
+
162
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
163
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
164
+ if require_grad_only:
165
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
166
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu()
167
+ for k, v in to_return.items()}
168
+ return to_return
169
+
170
+
171
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
172
+ to_return = {k: t for k, t in named_params if any(
173
+ key_match in k for key_match in keys_to_match)}
174
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu()
175
+ for k, v in to_return.items()}
176
+ return to_return
177
+
178
+
179
+ def get_vision_tower_state_maybe_zero_3(named_params, keys_to_match=['']):
180
+ to_return = {k: t for k, t in named_params if any(
181
+ key_match in k for key_match in keys_to_match)}
182
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu()
183
+ for k, v in to_return.items()}
184
+ return to_return
185
+
186
+
187
+ def find_all_linear_names(model, qv_proj_only=False):
188
+ if qv_proj_only:
189
+ rank0_print('Only add LoRA to QV proj')
190
+ return ['q_proj', 'v_proj']
191
+ cls = torch.nn.Linear
192
+ lora_module_names = set()
193
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
194
+ for name, module in model.named_modules():
195
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
196
+ continue
197
+ if isinstance(module, cls):
198
+ names = name.split('.')
199
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
200
+
201
+ if 'lm_head' in lora_module_names: # needed for 16-bit
202
+ lora_module_names.remove('lm_head')
203
+ return list(lora_module_names)
204
+
205
+
206
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
207
+ output_dir: str):
208
+ """Collects the state dict and dump to disk."""
209
+
210
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
211
+ # Only save Adapter
212
+ keys_to_match = ['mm_projector']
213
+ if getattr(trainer.args, "use_im_start_end", False):
214
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
215
+
216
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(
217
+ trainer.model.named_parameters(), keys_to_match)
218
+ trainer.model.config.save_pretrained(output_dir)
219
+
220
+ current_folder = output_dir.split('/')[-1]
221
+ parent_folder = os.path.dirname(output_dir)
222
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
223
+ if current_folder.startswith('checkpoint-'):
224
+ mm_projector_folder = os.path.join(
225
+ parent_folder, "mm_projector")
226
+ os.makedirs(mm_projector_folder, exist_ok=True)
227
+ torch.save(weight_to_save, os.path.join(
228
+ mm_projector_folder, f'{current_folder}.bin'))
229
+ else:
230
+ torch.save(weight_to_save, os.path.join(
231
+ output_dir, f'mm_projector.bin'))
232
+ return
233
+
234
+ if getattr(trainer.args, "unfreeze_mm_vision_tower", False):
235
+ if trainer.deepspeed:
236
+ torch.cuda.synchronize()
237
+ mm_vision_tower_folder = os.path.join(output_dir, 'vision_tower')
238
+ os.makedirs(mm_vision_tower_folder, exist_ok=True)
239
+ trainer.model.get_vision_tower().image_processor.save_pretrained(mm_vision_tower_folder)
240
+ trainer.model.get_vision_tower().vision_tower.vision_model.config.save_pretrained(
241
+ mm_vision_tower_folder)
242
+ weight_to_save = get_vision_tower_state_maybe_zero_3(
243
+ trainer.model.get_vision_tower().vision_tower.named_parameters())
244
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
245
+ torch.save(weight_to_save, os.path.join(
246
+ mm_vision_tower_folder, 'pytorch_model.bin'))
247
+
248
+ if getattr(trainer.model.model, 'vision_tower', None) is not None:
249
+ del trainer.model.model.vision_tower
250
+
251
+ if trainer.deepspeed:
252
+ torch.cuda.synchronize()
253
+ trainer.save_model(output_dir)
254
+ return
255
+
256
+ state_dict = trainer.model.state_dict()
257
+ if trainer.args.should_save:
258
+ cpu_state_dict = {
259
+ key: value.cpu()
260
+ for key, value in state_dict.items()
261
+ }
262
+ del state_dict
263
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
264
+
265
+
266
+ def smart_tokenizer_and_embedding_resize(
267
+ special_tokens_dict: Dict,
268
+ tokenizer: transformers.PreTrainedTokenizer,
269
+ model: transformers.PreTrainedModel,
270
+ ):
271
+ """Resize tokenizer and embedding.
272
+
273
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
274
+ """
275
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
276
+ model.resize_token_embeddings(len(tokenizer))
277
+
278
+ if num_new_tokens > 0:
279
+ input_embeddings = model.get_input_embeddings().weight.data
280
+ output_embeddings = model.get_output_embeddings().weight.data
281
+
282
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
283
+ dim=0, keepdim=True)
284
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
285
+ dim=0, keepdim=True)
286
+
287
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
288
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
289
+
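A hedged sketch of the resize-and-mean-init behaviour above, using a tiny public checkpoint purely as a stand-in (the model name is an assumption):

```python
# Illustrative only: the new [PAD] embedding row starts at the mean of the old rows.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")   # assumed tiny model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

smart_tokenizer_and_embedding_resize(
    special_tokens_dict=dict(pad_token="[PAD]"),
    tokenizer=tokenizer,
    model=model,
)
print(len(tokenizer), model.get_input_embeddings().weight.shape)
```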
290
+
291
+ def _tokenize_fn(strings: Sequence[str],
292
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
293
+ """Tokenize a list of strings."""
294
+ tokenized_list = [
295
+ tokenizer(
296
+ text,
297
+ return_tensors="pt",
298
+ padding="longest",
299
+ max_length=tokenizer.model_max_length,
300
+ truncation=True,
301
+ ) for text in strings
302
+ ]
303
+ input_ids = labels = [
304
+ tokenized.input_ids[0] for tokenized in tokenized_list
305
+ ]
306
+ input_ids_lens = labels_lens = [
307
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
308
+ for tokenized in tokenized_list
309
+ ]
310
+ return dict(
311
+ input_ids=input_ids,
312
+ labels=labels,
313
+ input_ids_lens=input_ids_lens,
314
+ labels_lens=labels_lens,
315
+ )
316
+
317
+
318
+ def _mask_targets(target, tokenized_lens, speakers):
319
+ # cur_idx = 0
320
+ cur_idx = tokenized_lens[0]
321
+ tokenized_lens = tokenized_lens[1:]
322
+ target[:cur_idx] = IGNORE_INDEX
323
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
324
+ if speaker == "human":
325
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
326
+ cur_idx += tokenized_len
327
+
328
+
329
+ def _add_speaker_and_signal(header, source, get_conversation=True):
330
+ """Add speaker and start/end signal on each round."""
331
+ BEGIN_SIGNAL = "### "
332
+ END_SIGNAL = "\n"
333
+ conversation = header
334
+ for sentence in source:
335
+ from_str = sentence["from"]
336
+ if from_str.lower() == "human":
337
+ from_str = conversation_lib.default_conversation.roles[0]
338
+ elif from_str.lower() == "gpt":
339
+ from_str = conversation_lib.default_conversation.roles[1]
340
+ else:
341
+ from_str = 'unknown'
342
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
343
+ sentence["value"] + END_SIGNAL)
344
+ if get_conversation:
345
+ conversation += sentence["value"]
346
+ conversation += BEGIN_SIGNAL
347
+ return conversation
348
+
349
+
350
+ def preprocess_multimodal(
351
+ sources: Sequence[str],
352
+ data_args: DataArguments
353
+ ) -> Dict:
354
+ is_multimodal = data_args.is_multimodal
355
+ if not is_multimodal:
356
+ return sources
357
+
358
+ for source in sources:
359
+ for sentence in source:
360
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
361
+ sentence['value'] = sentence['value'].replace(
362
+ DEFAULT_IMAGE_TOKEN, '').strip()
363
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + \
364
+ '\n' + sentence['value']
365
+ sentence['value'] = sentence['value'].strip()
366
+ if "mmtag" in conversation_lib.default_conversation.version:
367
+ sentence['value'] = sentence['value'].replace(
368
+ DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
369
+ replace_token = DEFAULT_IMAGE_TOKEN
370
+ if data_args.mm_use_im_start_end:
371
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
372
+ sentence["value"] = sentence["value"].replace(
373
+ DEFAULT_IMAGE_TOKEN, replace_token)
374
+
375
+ return sources
376
+
377
+
378
+ def preprocess_llama_2(
379
+ sources,
380
+ tokenizer: transformers.PreTrainedTokenizer,
381
+ has_image: bool = False
382
+ ) -> Dict:
383
+ conv = conversation_lib.default_conversation.copy()
384
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
385
+
386
+ # Apply prompt templates
387
+ conversations = []
388
+ for i, source in enumerate(sources):
389
+ if roles[source[0]["from"]] != conv.roles[0]:
390
+ # Skip the first one if it is not from human
391
+ source = source[1:]
392
+
393
+ conv.messages = []
394
+ for j, sentence in enumerate(source):
395
+ role = roles[sentence["from"]]
396
+ assert role == conv.roles[j % 2], f"{i}"
397
+ conv.append_message(role, sentence["value"])
398
+ conversations.append(conv.get_prompt())
399
+
400
+ # Tokenize conversations
401
+
402
+ if has_image:
403
+ input_ids = torch.stack([tokenizer_image_token(
404
+ prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
405
+ else:
406
+ input_ids = tokenizer(
407
+ conversations,
408
+ return_tensors="pt",
409
+ padding="longest",
410
+ max_length=tokenizer.model_max_length,
411
+ truncation=True,
412
+ ).input_ids
413
+
414
+ targets = input_ids.clone()
415
+
416
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
417
+
418
+ # Mask targets
419
+ sep = "[/INST] "
420
+ for conversation, target in zip(conversations, targets):
421
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
422
+
423
+ rounds = conversation.split(conv.sep2)
424
+ cur_len = 1
425
+ target[:cur_len] = IGNORE_INDEX
426
+ for i, rou in enumerate(rounds):
427
+ if rou == "":
428
+ break
429
+
430
+ parts = rou.split(sep)
431
+ if len(parts) != 2:
432
+ break
433
+ parts[0] += sep
434
+
435
+ if has_image:
436
+ round_len = len(tokenizer_image_token(rou, tokenizer))
437
+ instruction_len = len(
438
+ tokenizer_image_token(parts[0], tokenizer)) - 2
439
+ else:
440
+ round_len = len(tokenizer(rou).input_ids)
441
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
442
+
443
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
444
+
445
+ cur_len += round_len
446
+ target[cur_len:] = IGNORE_INDEX
447
+
448
+ if cur_len < tokenizer.model_max_length:
449
+ if cur_len != total_len:
450
+ target[:] = IGNORE_INDEX
451
+ print(
452
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
453
+ f" (ignored)"
454
+ )
455
+
456
+ return dict(
457
+ input_ids=input_ids,
458
+ labels=targets,
459
+ )
460
+
461
+
462
+ def preprocess_llama3(
463
+ sources,
464
+ tokenizer: transformers.PreTrainedTokenizer,
465
+ has_image: bool = False
466
+ ) -> Dict:
467
+ conv = conversation_lib.default_conversation.copy()
468
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
469
+
470
+ # Apply prompt templates
471
+ conversations = []
472
+ for i, source in enumerate(sources):
473
+ if roles[source[0]["from"]] != conv.roles[0]:
474
+ # Skip the first one if it is not from human
475
+ source = source[1:]
476
+
477
+ conv.messages = []
478
+ for j, sentence in enumerate(source):
479
+ role = roles[sentence["from"]]
480
+ assert role == conv.roles[j % 2], f"{i}"
481
+ conv.append_message(role, sentence["value"])
482
+ conversations.append(conv.get_prompt())
483
+
484
+ # Tokenize conversations
485
+
486
+ if has_image:
487
+ input_ids = torch.stack(
488
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
489
+ else:
490
+ input_ids = tokenizer(
491
+ conversations,
492
+ return_tensors="pt",
493
+ padding="longest",
494
+ max_length=tokenizer.model_max_length,
495
+ truncation=True,
496
+ ).input_ids
497
+
498
+ targets = input_ids.clone()
499
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
500
+
501
+ # Mask targets
502
+ sep = conv.sep + conv.roles[1]
503
+ for conversation, target in zip(conversations, targets):
504
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
505
+
506
+ rounds = conversation.split(conv.sep)
507
+ re_rounds = [conv.sep.join(rounds[:3])]
508
+ for conv_idx in range(3, len(rounds), 2):
509
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2]))
510
+ cur_len = 0
511
+ target[:cur_len] = IGNORE_INDEX
512
+ for i, rou in enumerate(re_rounds):
513
+ if rou == "":
514
+ break
515
+
516
+ parts = rou.split(sep)
517
+ if len(parts) != 2:
518
+ break
519
+ parts[0] += sep
520
+
521
+ if has_image:
522
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
523
+ instruction_len = len(
524
+ tokenizer_image_token(parts[0], tokenizer))
525
+ else:
526
+ round_len = len(tokenizer(rou).input_ids) + 1
527
+ instruction_len = len(tokenizer(parts[0]).input_ids)
528
+
529
+ if i > 0:
530
+ round_len -= 1
531
+ instruction_len -= 1
532
+
533
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
534
+
535
+ cur_len += round_len
536
+ target[cur_len:] = IGNORE_INDEX
537
+
538
+ if cur_len < tokenizer.model_max_length:
539
+ if cur_len != total_len:
540
+ target[:] = IGNORE_INDEX
541
+ print(
542
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
543
+ f" (ignored)"
544
+ )
545
+
546
+ return dict(
547
+ input_ids=input_ids,
548
+ labels=targets,
549
+ )
550
+
551
+
552
+ def preprocess_yi(
553
+ sources,
554
+ tokenizer: transformers.PreTrainedTokenizer,
555
+ has_image: bool = False
556
+ ) -> Dict:
557
+ conv = conversation_lib.default_conversation.copy()
558
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
559
+
560
+ # Apply prompt templates
561
+ conversations = []
562
+ for i, source in enumerate(sources):
563
+ if roles[source[0]["from"]] != conv.roles[0]:
564
+ # Skip the first one if it is not from human
565
+ source = source[1:]
566
+
567
+ conv.messages = []
568
+ for j, sentence in enumerate(source):
569
+ role = roles[sentence["from"]]
570
+ assert role == conv.roles[j % 2], f"{i}"
571
+ conv.append_message(role, sentence["value"])
572
+ conversations.append(conv.get_prompt())
573
+
574
+ # Tokenize conversations
575
+
576
+ if has_image:
577
+ input_ids = torch.stack(
578
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
579
+ else:
580
+ input_ids = tokenizer(
581
+ conversations,
582
+ return_tensors="pt",
583
+ padding="longest",
584
+ max_length=tokenizer.model_max_length,
585
+ truncation=True,
586
+ ).input_ids
587
+
588
+ targets = input_ids.clone()
589
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
590
+
591
+ # Mask targets
592
+ sep = conv.sep + conv.roles[1]
593
+ for conversation, target in zip(conversations, targets):
594
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
595
+
596
+ rounds = conversation.split(conv.sep)
597
+ re_rounds = [conv.sep.join(rounds[:3])]
598
+ for conv_idx in range(3, len(rounds), 2):
599
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2]))
600
+ cur_len = 0
601
+ target[:cur_len] = IGNORE_INDEX
602
+ for i, rou in enumerate(re_rounds):
603
+ if rou == "":
604
+ break
605
+
606
+ parts = rou.split(sep)
607
+ if len(parts) != 2:
608
+ break
609
+ parts[0] += sep
610
+
611
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
612
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
613
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
614
+
615
+ cur_len += round_len
616
+ target[cur_len:] = IGNORE_INDEX
617
+
618
+ if cur_len < tokenizer.model_max_length:
619
+ if cur_len != total_len:
620
+ target[:] = IGNORE_INDEX
621
+ print(
622
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
623
+ f" (ignored)"
624
+ )
625
+
626
+ return dict(
627
+ input_ids=input_ids,
628
+ labels=targets,
629
+ )
630
+
631
+
632
+ def preprocess_v1(
633
+ sources,
634
+ tokenizer: transformers.PreTrainedTokenizer,
635
+ has_image: bool = False
636
+ ) -> Dict:
637
+ conv = conversation_lib.default_conversation.copy()
638
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
639
+
640
+ # Apply prompt templates
641
+ conversations = []
642
+ for i, source in enumerate(sources):
643
+ if roles[source[0]["from"]] != conv.roles[0]:
644
+ # Skip the first one if it is not from human
645
+ source = source[1:]
646
+
647
+ conv.messages = []
648
+ for j, sentence in enumerate(source):
649
+ role = roles[sentence["from"]]
650
+ assert role == conv.roles[j % 2], f"{i}"
651
+ conv.append_message(role, sentence["value"])
652
+ conversations.append(conv.get_prompt())
653
+
654
+ # Tokenize conversations
655
+
656
+ if has_image:
657
+ input_ids = torch.stack([tokenizer_image_token(
658
+ prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
659
+ else:
660
+ input_ids = tokenizer(
661
+ conversations,
662
+ return_tensors="pt",
663
+ padding="longest",
664
+ max_length=tokenizer.model_max_length,
665
+ truncation=True,
666
+ ).input_ids
667
+
668
+ targets = input_ids.clone()
669
+
670
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
671
+
672
+ # Mask targets
673
+ sep = conv.sep + conv.roles[1] + ": "
674
+ for conversation, target in zip(conversations, targets):
675
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
676
+
677
+ rounds = conversation.split(conv.sep2)
678
+ cur_len = 1
679
+ target[:cur_len] = IGNORE_INDEX
680
+ for i, rou in enumerate(rounds):
681
+ if rou == "":
682
+ break
683
+
684
+ parts = rou.split(sep)
685
+ if len(parts) != 2:
686
+ break
687
+ parts[0] += sep
688
+
689
+ if has_image:
690
+ round_len = len(tokenizer_image_token(rou, tokenizer))
691
+ instruction_len = len(
692
+ tokenizer_image_token(parts[0], tokenizer)) - 2
693
+ else:
694
+ round_len = len(tokenizer(rou).input_ids)
695
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
696
+
697
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
698
+ round_len -= 1
699
+ instruction_len -= 1
700
+
701
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
702
+
703
+ cur_len += round_len
704
+ target[cur_len:] = IGNORE_INDEX
705
+
706
+ if cur_len < tokenizer.model_max_length:
707
+ if cur_len != total_len:
708
+ target[:] = IGNORE_INDEX
709
+ print(
710
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
711
+ f" (ignored)"
712
+ )
713
+
714
+ return dict(
715
+ input_ids=input_ids,
716
+ labels=targets,
717
+ )
718
+
719
+
720
+ def preprocess_mpt(
721
+ sources,
722
+ tokenizer: transformers.PreTrainedTokenizer,
723
+ has_image: bool = False
724
+ ) -> Dict:
725
+ conv = conversation_lib.default_conversation.copy()
726
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
727
+
728
+ # Apply prompt templates
729
+ conversations = []
730
+ for i, source in enumerate(sources):
731
+ if roles[source[0]["from"]] != conv.roles[0]:
732
+ # Skip the first one if it is not from human
733
+ source = source[1:]
734
+
735
+ conv.messages = []
736
+ for j, sentence in enumerate(source):
737
+ role = roles[sentence["from"]]
738
+ assert role == conv.roles[j % 2], f"{i}"
739
+ conv.append_message(role, sentence["value"])
740
+ conversations.append(conv.get_prompt())
741
+
742
+ # Tokenize conversations
743
+
744
+ if has_image:
745
+ input_ids = torch.stack([tokenizer_image_token(
746
+ prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
747
+ else:
748
+ input_ids = tokenizer(
749
+ conversations,
750
+ return_tensors="pt",
751
+ padding="longest",
752
+ max_length=tokenizer.model_max_length,
753
+ truncation=True,
754
+ ).input_ids
755
+
756
+ targets = input_ids.clone()
757
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
758
+
759
+ # Mask targets
760
+ sep = conv.sep + conv.roles[1]
761
+ for conversation, target in zip(conversations, targets):
762
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
763
+
764
+ rounds = conversation.split(conv.sep)
765
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
766
+ for conv_idx in range(3, len(rounds), 2):
767
+ re_rounds.append(conv.sep.join(
768
+ rounds[conv_idx:conv_idx+2])) # user + gpt
769
+ cur_len = 0
770
+ target[:cur_len] = IGNORE_INDEX
771
+ for i, rou in enumerate(re_rounds):
772
+ if rou == "":
773
+ break
774
+
775
+ parts = rou.split(sep)
776
+ if len(parts) != 2:
777
+ break
778
+ parts[0] += sep
779
+ # not included <|im_end|>
780
+ if has_image:
781
+ round_len = len(tokenizer_image_token(rou, tokenizer))
782
+ instruction_len = len(
783
+ tokenizer_image_token(parts[0], tokenizer)) - 1
784
+ else:
785
+ round_len = len(tokenizer(rou).input_ids)
786
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
787
+
788
+ # include <|im_end|> for all rounds
789
+ # if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
790
+ if getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
791
+ round_len += 1
792
+ instruction_len += 1
793
+
794
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
795
+
796
+ cur_len += round_len
797
+ target[cur_len:] = IGNORE_INDEX
798
+
799
+ if cur_len < tokenizer.model_max_length:
800
+ if cur_len != total_len:
801
+ target[:] = IGNORE_INDEX
802
+ print(
803
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
804
+ f" (ignored)"
805
+ )
806
+
807
+ return dict(
808
+ input_ids=input_ids,
809
+ labels=targets,
810
+ )
811
+
812
+
813
+ def preprocess_plain(
814
+ sources: Sequence[str],
815
+ tokenizer: transformers.PreTrainedTokenizer,
816
+ ) -> Dict:
817
+ # add end signal and concatenate together
818
+ conversations = []
819
+ for source in sources:
820
+ assert len(source) == 2
821
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
822
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
823
+ conversation = source[0]['value'] + source[1]['value'] + \
824
+ conversation_lib.default_conversation.sep
825
+ conversations.append(conversation)
826
+ # tokenize conversations
827
+ input_ids = [tokenizer_image_token(
828
+ prompt, tokenizer, return_tensors='pt') for prompt in conversations]
829
+ targets = copy.deepcopy(input_ids)
830
+ for target, source in zip(targets, sources):
831
+ tokenized_len = len(tokenizer_image_token(
832
+ source[0]['value'], tokenizer))
833
+ target[:tokenized_len] = IGNORE_INDEX
834
+
835
+ return dict(input_ids=input_ids, labels=targets)
836
+
837
+
838
+ def preprocess(
839
+ sources: Sequence[str],
840
+ tokenizer: transformers.PreTrainedTokenizer,
841
+ has_image: bool = False
842
+ ) -> Dict:
843
+ """
844
+ Given a list of sources, each is a conversation list. This transform:
845
+ 1. Add the signal '### ' at the beginning of each sentence, with the end signal '\n';
846
+ 2. Concatenate conversations together;
847
+ 3. Tokenize the concatenated conversation;
848
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
849
+ """
850
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
851
+ return preprocess_plain(sources, tokenizer)
852
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
853
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
854
+ if conversation_lib.default_conversation.version.startswith("v1"):
855
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
856
+ if conversation_lib.default_conversation.version == "mpt":
857
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
858
+ if conversation_lib.default_conversation.version in ["llama3", "llava_llama_3"]:
859
+ return preprocess_llama3(sources, tokenizer, has_image=has_image)
860
+ if conversation_lib.default_conversation.version == "yi":
861
+ return preprocess_yi(sources, tokenizer, has_image=has_image)
862
+ # add end signal and concatenate together
863
+ conversations = []
864
+ for source in sources:
865
+ header = f"{conversation_lib.default_conversation.system}\n\n"
866
+ conversation = _add_speaker_and_signal(header, source)
867
+ conversations.append(conversation)
868
+ # tokenize conversations
869
+
870
+ def get_tokenize_len(prompts):
871
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
872
+
873
+ if has_image:
874
+ input_ids = [tokenizer_image_token(
875
+ prompt, tokenizer, return_tensors='pt') for prompt in conversations]
876
+ else:
877
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
878
+ input_ids = conversations_tokenized["input_ids"]
879
+
880
+ targets = copy.deepcopy(input_ids)
881
+ for target, source in zip(targets, sources):
882
+ if has_image:
883
+ tokenized_lens = get_tokenize_len(
884
+ [header] + [s["value"] for s in source])
885
+ else:
886
+ tokenized_lens = _tokenize_fn(
887
+ [header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
888
+ speakers = [sentence["from"] for sentence in source]
889
+ _mask_targets(target, tokenized_lens, speakers)
890
+
891
+ return dict(input_ids=input_ids, labels=targets)
892
+
893
+
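# Standalone sketch of the masking scheme implemented above (illustrative values):
# labels begin as a copy of input_ids, and every system/human span is overwritten
# with IGNORE_INDEX (assumed to be -100, the Hugging Face convention) so that only
# assistant tokens contribute to the cross-entropy loss.
import torch

IGNORE_INDEX_SKETCH = -100                                   # assumption, see lead-in
input_ids = torch.tensor([1, 15, 16, 17, 42, 43, 44, 2])     # toy conversation tokens
labels = input_ids.clone()
labels[:4] = IGNORE_INDEX_SKETCH                             # mask prompt/human turn
# loss is then computed only on positions 4..7 (the assistant reply and EOS)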
894
+ class LazySupervisedDataset(Dataset):
895
+ """Dataset for supervised fine-tuning."""
896
+
897
+ def __init__(self, data_path: str,
898
+ tokenizer: transformers.PreTrainedTokenizer,
899
+ data_args: DataArguments):
900
+ super(LazySupervisedDataset, self).__init__()
901
+ list_data_dict = json.load(open(data_path, "r"))
902
+
903
+ rank0_print("Formatting inputs...Skip in lazy mode")
904
+ self.tokenizer = tokenizer
905
+ self.list_data_dict = list_data_dict
906
+ self.data_args = data_args
907
+
908
+ def __len__(self):
909
+ return len(self.list_data_dict)
910
+
911
+ @property
912
+ def lengths(self):
913
+ length_list = []
914
+ for sample in self.list_data_dict:
915
+ img_tokens = 128 if 'image' in sample else 0
916
+ length_list.append(sum(len(conv['value'].split())
917
+ for conv in sample['conversations']) + img_tokens)
918
+ return length_list
919
+
920
+ @property
921
+ def modality_lengths(self):
922
+ length_list = []
923
+ for sample in self.list_data_dict:
924
+ cur_len = sum(len(conv['value'].split())
925
+ for conv in sample['conversations'])
926
+ cur_len = cur_len if 'image' in sample else -cur_len
927
+ length_list.append(cur_len)
928
+ return length_list
929
+
930
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
931
+ sources = self.list_data_dict[i]
932
+ if isinstance(i, int):
933
+ sources = [sources]
934
+ assert len(
935
+ sources) == 1, "Don't know why it is wrapped to a list" # FIXME
936
+ if 'image' in sources[0]:
937
+ image_file = self.list_data_dict[i]['image']
938
+ image_folder = self.data_args.image_folder
939
+ processor = self.data_args.image_processor
940
+ image = Image.open(os.path.join(
941
+ image_folder, image_file)).convert('RGB')
942
+ if self.data_args.image_aspect_ratio == 'pad':
943
+ def expand2square(pil_img, background_color):
944
+ width, height = pil_img.size
945
+ if width == height:
946
+ return pil_img
947
+ elif width > height:
948
+ result = Image.new(
949
+ pil_img.mode, (width, width), background_color)
950
+ result.paste(pil_img, (0, (width - height) // 2))
951
+ return result
952
+ else:
953
+ result = Image.new(
954
+ pil_img.mode, (height, height), background_color)
955
+ result.paste(pil_img, ((height - width) // 2, 0))
956
+ return result
957
+
958
+ image = expand2square(image, tuple(int(x * 255)
959
+ for x in processor.image_mean))
960
+ image_size = image.size
961
+ image = processor.preprocess(image, return_tensors='pt')[
962
+ 'pixel_values'][0]
963
+ elif self.data_args.image_aspect_ratio == "anyres":
964
+ image_size = image.size
965
+ image = process_anyres_image(
966
+ image, processor, self.data_args.image_grid_pinpoints)
967
+ else:
968
+ image_size = image.size
969
+ image = processor.preprocess(image, return_tensors='pt')[
970
+ 'pixel_values'][0]
971
+ sources = preprocess_multimodal(
972
+ copy.deepcopy([e["conversations"] for e in sources]),
973
+ self.data_args)
974
+ else:
975
+ sources = copy.deepcopy([e["conversations"] for e in sources])
976
+ data_dict = preprocess(
977
+ sources,
978
+ self.tokenizer,
979
+ has_image=('image' in self.list_data_dict[i]))
980
+ if isinstance(i, int):
981
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
982
+ labels=data_dict["labels"][0])
983
+
984
+ # image exists in the data
985
+ if 'image' in self.list_data_dict[i]:
986
+ data_dict['image'] = image
987
+ data_dict['image_size'] = image_size
988
+ elif self.data_args.is_multimodal:
989
+ # image does not exist in the data, but the model is multimodal
990
+ crop_size = self.data_args.image_processor.crop_size
991
+ data_dict['image'] = torch.zeros(
992
+ 3, crop_size['height'], crop_size['width'])
993
+ data_dict['image_size'] = (crop_size['height'], crop_size['width'])
994
+ return data_dict
995
+
996
+
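# Standalone sketch of the 'pad' aspect-ratio branch handled by expand2square() above:
# a 16:9 image is letterboxed into a square filled with the processor's mean color.
# The mean values and image size below are illustrative assumptions.
from PIL import Image

img = Image.new("RGB", (640, 360), (30, 120, 200))           # hypothetical input frame
bg = tuple(int(x * 255) for x in (0.481, 0.458, 0.408))      # assumed processor.image_mean
sq = Image.new(img.mode, (640, 640), bg)
sq.paste(img, (0, (640 - 360) // 2))                         # centred vertically
assert sq.size == (640, 640)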
997
+ @dataclass
998
+ class DataCollatorForSupervisedDataset(object):
999
+ """Collate examples for supervised fine-tuning."""
1000
+
1001
+ tokenizer: transformers.PreTrainedTokenizer
1002
+
1003
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
1004
+ input_ids, labels = tuple([instance[key] for instance in instances]
1005
+ for key in ("input_ids", "labels"))
1006
+ input_ids = torch.nn.utils.rnn.pad_sequence(
1007
+ input_ids,
1008
+ batch_first=True,
1009
+ padding_value=self.tokenizer.pad_token_id)
1010
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
1011
+ batch_first=True,
1012
+ padding_value=IGNORE_INDEX)
1013
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
1014
+ labels = labels[:, :self.tokenizer.model_max_length]
1015
+ batch = dict(
1016
+ input_ids=input_ids,
1017
+ labels=labels,
1018
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
1019
+ )
1020
+
1021
+ if 'image' in instances[0]:
1022
+ images = [instance['image'] for instance in instances]
1023
+ image_sizes = [instance['image_size'] for instance in instances]
1024
+ if all(x is not None and x.shape == images[0].shape for x in images):
1025
+ batch['images'] = torch.stack(images)
1026
+ else:
1027
+ batch['images'] = images
1028
+ batch['image_sizes'] = image_sizes
1029
+
1030
+ return batch
1031
+
1032
+
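# Standalone sketch of what the collator above produces for two ragged examples
# (a pad id of 0 is an assumption; the real value comes from the tokenizer).
import torch

pad_id = 0
a, b = torch.tensor([5, 6, 7]), torch.tensor([8, 9])
input_ids = torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=pad_id)
attention_mask = input_ids.ne(pad_id)
# input_ids      -> [[5, 6, 7], [8, 9, 0]]
# attention_mask -> [[True, True, True], [True, True, False]]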
1033
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
1034
+ data_args) -> Dict:
1035
+ """Make dataset and collator for supervised fine-tuning."""
1036
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
1037
+ data_path=data_args.data_path,
1038
+ data_args=data_args)
1039
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
1040
+ return dict(train_dataset=train_dataset,
1041
+ eval_dataset=None,
1042
+ data_collator=data_collator)
1043
+
1044
+
1045
+ def unfreeze_vit(vision_tower):
1046
+ for _, p in vision_tower.named_parameters():
1047
+ p.requires_grad = True
1048
+
1049
+
1050
+ def format_bytes(size):
1051
+ billion = 10**9
1052
+ million = 10**6
1053
+
1054
+ if size >= billion:
1055
+ return f"{size / billion:.2f}B"
1056
+ elif size >= million:
1057
+ return f"{size / million:.2f}M"
1058
+ else:
1059
+ return f"{size} bytes"
1060
+
1061
+
1062
+ def train(attn_implementation=None):
1063
+ global local_rank
1064
+
1065
+ parser = transformers.HfArgumentParser(
1066
+ (ModelArguments, DataArguments, TrainingArguments))
1067
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
1068
+ local_rank = training_args.local_rank
1069
+ compute_dtype = (torch.float16 if training_args.fp16 else (
1070
+ torch.bfloat16 if training_args.bf16 else torch.float32))
1071
+
1072
+ bnb_model_from_pretrained_args = {}
1073
+ if training_args.bits in [4, 8]:
1074
+ from transformers import BitsAndBytesConfig
1075
+ bnb_model_from_pretrained_args.update(dict(
1076
+ device_map={"": training_args.device},
1077
+ load_in_4bit=training_args.bits == 4,
1078
+ load_in_8bit=training_args.bits == 8,
1079
+ quantization_config=BitsAndBytesConfig(
1080
+ load_in_4bit=training_args.bits == 4,
1081
+ load_in_8bit=training_args.bits == 8,
1082
+ llm_int8_skip_modules=["mm_projector"],
1083
+ llm_int8_threshold=6.0,
1084
+ llm_int8_has_fp16_weight=False,
1085
+ bnb_4bit_compute_dtype=compute_dtype,
1086
+ bnb_4bit_use_double_quant=training_args.double_quant,
1087
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
1088
+ )
1089
+ ))
1090
+ model_max_length_args = {}
1091
+ if 'llava-v1.6-8b' not in model_args.model_name_or_path:
1092
+ config = transformers.AutoConfig.from_pretrained(
1093
+ model_args.model_name_or_path, trust_remote_code=True)
1094
+ if config.max_position_embeddings < training_args.model_max_length:
1095
+ rank0_print(
1096
+ f'Set the max_position_embeddings from {config.max_position_embeddings} to {training_args.model_max_length}')
1097
+ model_max_length_args.update(
1098
+ {'max_position_embeddings': training_args.model_max_length})
1099
+ if model_args.vision_tower is not None:
1100
+ if 'mpt' in model_args.model_name_or_path:
1101
+ config = transformers.AutoConfig.from_pretrained(
1102
+ model_args.model_name_or_path, trust_remote_code=True)
1103
+ config.attn_config['attn_impl'] = training_args.mpt_attn_impl
1104
+ model = LlavaMptForCausalLM.from_pretrained(
1105
+ model_args.model_name_or_path,
1106
+ config=config,
1107
+ cache_dir=training_args.cache_dir,
1108
+ **bnb_model_from_pretrained_args
1109
+ )
1110
+ else:
1111
+ model = LlavaLlamaForCausalLM.from_pretrained(
1112
+ model_args.model_name_or_path,
1113
+ cache_dir=training_args.cache_dir,
1114
+ attn_implementation=attn_implementation,
1115
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1116
+ **bnb_model_from_pretrained_args,
1117
+ **model_max_length_args
1118
+ )
1119
+ else:
1120
+ model = transformers.LlamaForCausalLM.from_pretrained(
1121
+ model_args.model_name_or_path,
1122
+ cache_dir=training_args.cache_dir,
1123
+ attn_implementation=attn_implementation,
1124
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
1125
+ **bnb_model_from_pretrained_args
1126
+ )
1127
+ model.config.use_cache = False
1128
+
1129
+ if model_args.freeze_backbone:
1130
+ model.model.requires_grad_(False)
1131
+
1132
+ if training_args.bits in [4, 8]:
1133
+ from peft import prepare_model_for_kbit_training
1134
+ model.config.torch_dtype = (torch.float32 if training_args.fp16 else (
1135
+ torch.bfloat16 if training_args.bf16 else torch.float32))
1136
+ model = prepare_model_for_kbit_training(
1137
+ model, use_gradient_checkpointing=training_args.gradient_checkpointing)
1138
+
1139
+ if training_args.gradient_checkpointing:
1140
+ if hasattr(model, "enable_input_require_grads"):
1141
+ model.enable_input_require_grads()
1142
+ else:
1143
+ def make_inputs_require_grad(module, input, output):
1144
+ output.requires_grad_(True)
1145
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
1146
+
1147
+ if training_args.lora_enable:
1148
+ from peft import LoraConfig, get_peft_model
1149
+ lora_config = LoraConfig(
1150
+ r=training_args.lora_r,
1151
+ lora_alpha=training_args.lora_alpha,
1152
+ target_modules=find_all_linear_names(model, training_args.lora_qv_proj_only),
1153
+ lora_dropout=training_args.lora_dropout,
1154
+ bias=training_args.lora_bias,
1155
+ task_type="CAUSAL_LM",
1156
+ )
1157
+ if training_args.bits == 16:
1158
+ if training_args.bf16:
1159
+ model.to(torch.bfloat16)
1160
+ if training_args.fp16:
1161
+ model.to(torch.float16)
1162
+ rank0_print("Adding LoRA adapters...")
1163
+ model = get_peft_model(model, lora_config)
1164
+
1165
+ if 'mpt' in model_args.model_name_or_path:
1166
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1167
+ model_args.model_name_or_path,
1168
+ cache_dir=training_args.cache_dir,
1169
+ model_max_length=training_args.model_max_length,
1170
+ padding_side="right"
1171
+ )
1172
+ else:
1173
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
1174
+ model_args.model_name_or_path,
1175
+ cache_dir=training_args.cache_dir,
1176
+ model_max_length=training_args.model_max_length,
1177
+ padding_side="right",
1178
+ use_fast=False,
1179
+ )
1180
+
1181
+ if model_args.version == "v0":
1182
+ if tokenizer.pad_token is None:
1183
+ smart_tokenizer_and_embedding_resize(
1184
+ special_tokens_dict=dict(pad_token="[PAD]"),
1185
+ tokenizer=tokenizer,
1186
+ model=model,
1187
+ )
1188
+ elif model_args.version == "v0.5":
1189
+ tokenizer.pad_token = tokenizer.unk_token
1190
+ else:
1191
+ if tokenizer.pad_token is None:
1192
+ rank0_print("Adding pad token as '<pad>'")
1193
+ smart_tokenizer_and_embedding_resize(
1194
+ special_tokens_dict=dict(pad_token="<pad>"),
1195
+ tokenizer=tokenizer,
1196
+ model=model,
1197
+ )
1198
+ if model_args.version in conversation_lib.conv_templates:
1199
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
1200
+ model_args.version]
1201
+ else:
1202
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
1203
+
1204
+ if model_args.vision_tower is not None:
1205
+ model.get_model().initialize_vision_modules(
1206
+ model_args=model_args,
1207
+ fsdp=training_args.fsdp
1208
+ )
1209
+
1210
+ vision_tower = model.get_vision_tower()
1211
+ vision_tower.to(
1212
+ dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
1213
+
1214
+ data_args.image_processor = vision_tower.image_processor
1215
+ data_args.is_multimodal = True
1216
+
1217
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
1218
+ if data_args.image_aspect_ratio == 'anyres':
1219
+ base_size = vision_tower.config.image_size
1220
+ grids = [[1, 2], [2, 1], [2, 2], [3, 1], [1, 3]]
1221
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints = [
1222
+ [g[0]*base_size, g[1]*base_size] for g in grids]
1223
+ model.config.tokenizer_padding_side = tokenizer.padding_side
1224
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
1225
+
1226
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
1227
+ if model_args.tune_mm_mlp_adapter:
1228
+ model.requires_grad_(False)
1229
+ for p in model.get_model().mm_projector.parameters():
1230
+ p.requires_grad = True
1231
+
1232
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
1233
+ if training_args.freeze_mm_mlp_adapter:
1234
+ for p in model.get_model().mm_projector.parameters():
1235
+ p.requires_grad = False
1236
+
1237
+ model.config.unfreeze_mm_vision_tower = training_args.unfreeze_mm_vision_tower
1238
+ if training_args.unfreeze_mm_vision_tower:
1239
+ lr_of_vit = training_args.mm_vision_tower_lr if training_args.mm_vision_tower_lr is not None else training_args.learning_rate
1240
+ lr_of_mlp = training_args.mm_projector_lr if training_args.mm_projector_lr is not None else training_args.learning_rate
1241
+ training_args.mm_projector_lr = lr_of_mlp
1242
+ unfreeze_vit(vision_tower)
1243
+ rank0_print(
1244
+ f'Tune the entire model! The LR of ViT is {lr_of_vit}. The LR of MLP is {lr_of_mlp}. The LR of LLM is {training_args.learning_rate}')
1245
+
1246
+ # Calculate total parameters and trainable parameters
1247
+ total_params = sum(p.numel() for p in model.get_model().parameters())
1248
+ trainable_params = sum(
1249
+ p.numel() for p in model.get_model().parameters() if p.requires_grad)
1250
+
1251
+ rank0_print(f"Total parameters: {format_bytes(total_params)}")
1252
+ rank0_print(f"Trainable parameters: {format_bytes(trainable_params)}")
1253
+
1254
+ if training_args.bits in [4, 8]:
1255
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
1256
+
1257
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
1258
+ model.config.mm_projector_lr = training_args.mm_projector_lr
1259
+ model.config.mm_vision_tower_lr = training_args.mm_vision_tower_lr
1260
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
1261
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
1262
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
1263
+ model.config.pad_token_id = tokenizer.pad_token_id
1264
+
1265
+ if training_args.bits in [4, 8]:
1266
+ from peft.tuners.lora import LoraLayer
1267
+ for name, module in model.named_modules():
1268
+ if isinstance(module, LoraLayer):
1269
+ if training_args.bf16:
1270
+ module = module.to(torch.bfloat16)
1271
+ if 'norm' in name:
1272
+ module = module.to(torch.float32)
1273
+ if 'lm_head' in name or 'embed_tokens' in name:
1274
+ if hasattr(module, 'weight'):
1275
+ if training_args.bf16 and module.weight.dtype == torch.float32:
1276
+ module = module.to(torch.bfloat16)
1277
+
1278
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
1279
+ data_args=data_args)
1280
+ trainer = LLaVATrainer(model=model,
1281
+ tokenizer=tokenizer,
1282
+ args=training_args,
1283
+ **data_module)
1284
+
1285
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
1286
+ trainer.train(resume_from_checkpoint=True)
1287
+ else:
1288
+ trainer.train()
1289
+ trainer.save_state()
1290
+
1291
+ model.config.use_cache = True
1292
+
1293
+ if training_args.lora_enable:
1294
+ state_dict = get_peft_state_maybe_zero_3(
1295
+ model.named_parameters(), training_args.lora_bias
1296
+ )
1297
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1298
+ model.named_parameters()
1299
+ )
1300
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
1301
+ model.config.save_pretrained(training_args.output_dir)
1302
+ model.save_pretrained(
1303
+ training_args.output_dir, state_dict=state_dict)
1304
+ torch.save(non_lora_state_dict, os.path.join(
1305
+ training_args.output_dir, 'non_lora_trainables.bin'))
1306
+ if training_args.unfreeze_mm_vision_tower:
1307
+ if trainer.deepspeed:
1308
+ torch.cuda.synchronize()
1309
+ trainer.model.get_vision_tower().image_processor.save_pretrained(
1310
+ os.path.join(training_args.output_dir, 'vision_tower'))
1311
+ trainer.model.get_vision_tower().vision_tower.vision_model.config.save_pretrained(
1312
+ os.path.join(training_args.output_dir, 'vision_tower'))
1313
+ weight_to_save = get_vision_tower_state_maybe_zero_3(
1314
+ trainer.model.get_vision_tower().vision_tower.named_parameters())
1315
+ torch.save(weight_to_save, os.path.join(
1316
+ training_args.output_dir, 'vision_tower/pytorch_model.bin'))
1317
+ else:
1318
+ safe_save_model_for_hf_trainer(trainer=trainer,
1319
+ output_dir=training_args.output_dir)
1320
+
1321
+
1322
+ if __name__ == "__main__":
1323
+ train()
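The entry point above is driven entirely by HfArgumentParser over the ModelArguments, DataArguments and TrainingArguments dataclasses, so a run can be sketched by populating sys.argv before calling train(). The model, data and vision-tower paths below are placeholders, and real runs are normally launched through deepspeed/torchrun scripts rather than this way.

import sys

from llava.train.train import train

sys.argv = [
    "train.py",
    "--model_name_or_path", "lmsys/vicuna-7b-v1.5",          # placeholder base LLM
    "--version", "v1",
    "--data_path", "data/llava_instruct.json",               # placeholder instruction data
    "--image_folder", "data/images",                         # placeholder image root
    "--vision_tower", "openai/clip-vit-large-patch14-336",   # placeholder vision encoder
    "--output_dir", "checkpoints/llava-ft",
    "--bf16", "True",
    "--per_device_train_batch_size", "1",
    "--num_train_epochs", "1",
]
train()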
llava/train/train_mem.py ADDED
@@ -0,0 +1,4 @@
1
+ from llava.train.train import train
2
+
3
+ if __name__ == "__main__":
4
+ train(attn_implementation="flash_attention_2")
llava/utils.py ADDED
@@ -0,0 +1,126 @@
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from llava.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def build_logger(logger_name, logger_filename):
18
+ global handler
19
+
20
+ formatter = logging.Formatter(
21
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ )
24
+
25
+ # Set the format of root handlers
26
+ if not logging.getLogger().handlers:
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger().handlers[0].setFormatter(formatter)
29
+
30
+ # Redirect stdout and stderr to loggers
31
+ stdout_logger = logging.getLogger("stdout")
32
+ stdout_logger.setLevel(logging.INFO)
33
+ sl = StreamToLogger(stdout_logger, logging.INFO)
34
+ sys.stdout = sl
35
+
36
+ stderr_logger = logging.getLogger("stderr")
37
+ stderr_logger.setLevel(logging.ERROR)
38
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
39
+ sys.stderr = sl
40
+
41
+ # Get logger
42
+ logger = logging.getLogger(logger_name)
43
+ logger.setLevel(logging.INFO)
44
+
45
+ # Add a file handler for all loggers
46
+ if handler is None:
47
+ os.makedirs(LOGDIR, exist_ok=True)
48
+ filename = os.path.join(LOGDIR, logger_filename)
49
+ handler = logging.handlers.TimedRotatingFileHandler(
50
+ filename, when='D', utc=True, encoding='UTF-8')
51
+ handler.setFormatter(formatter)
52
+
53
+ for name, item in logging.root.manager.loggerDict.items():
54
+ if isinstance(item, logging.Logger):
55
+ item.addHandler(handler)
56
+
57
+ return logger
58
+
59
+
60
+ class StreamToLogger(object):
61
+ """
62
+ Fake file-like stream object that redirects writes to a logger instance.
63
+ """
64
+ def __init__(self, logger, log_level=logging.INFO):
65
+ self.terminal = sys.stdout
66
+ self.logger = logger
67
+ self.log_level = log_level
68
+ self.linebuf = ''
69
+
70
+ def __getattr__(self, attr):
71
+ return getattr(self.terminal, attr)
72
+
73
+ def write(self, buf):
74
+ temp_linebuf = self.linebuf + buf
75
+ self.linebuf = ''
76
+ for line in temp_linebuf.splitlines(True):
77
+ # From the io.TextIOWrapper docs:
78
+ # On output, if newline is None, any '\n' characters written
79
+ # are translated to the system default line separator.
80
+ # By default sys.stdout.write() expects '\n' newlines and then
81
+ # translates them so this is still cross platform.
82
+ if line[-1] == '\n':
83
+ self.logger.log(self.log_level, line.rstrip())
84
+ else:
85
+ self.linebuf += line
86
+
87
+ def flush(self):
88
+ if self.linebuf != '':
89
+ self.logger.log(self.log_level, self.linebuf.rstrip())
90
+ self.linebuf = ''
91
+
92
+
93
+ def disable_torch_init():
94
+ """
95
+ Disable the redundant torch default initialization to accelerate model creation.
96
+ """
97
+ import torch
98
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100
+
101
+
102
+ def violates_moderation(text):
103
+ """
104
+ Check whether the text violates the OpenAI moderation API.
105
+ """
106
+ url = "https://api.openai.com/v1/moderations"
107
+ headers = {"Content-Type": "application/json",
108
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109
+ text = text.replace("\n", "")
110
+ data = "{" + '"input": ' + f'"{text}"' + "}"
111
+ data = data.encode("utf-8")
112
+ try:
113
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
114
+ flagged = ret.json()["results"][0]["flagged"]
115
+ except requests.exceptions.RequestException as e:
116
+ flagged = False
117
+ except KeyError as e:
118
+ flagged = False
119
+
120
+ return flagged
121
+
122
+
123
+ def pretty_print_semaphore(semaphore):
124
+ if semaphore is None:
125
+ return "None"
126
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
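A brief usage sketch for the helpers above: build_logger routes both ordinary logging calls and print/stderr output into a daily-rotating file under LOGDIR, and disable_torch_init is typically called once before loading pretrained weights. The logger and file names below are illustrative.

from llava.utils import build_logger, disable_torch_init

logger = build_logger("controller", "controller.log")   # illustrative names
logger.info("service started")
print("this line is captured by StreamToLogger and written to the same file")

disable_torch_init()   # then load the pretrained checkpoint as usual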
llava/video_utils.py ADDED
@@ -0,0 +1,151 @@
1
+ import argparse
2
+ import io
3
+ import json
4
+ import os
5
+ import random
6
+ import tempfile
7
+ from multiprocessing import Manager, Pool, cpu_count
8
+
9
+ import cv2
10
+ import imageio
11
+ import numpy as np
12
+ from decord import VideoReader
13
+ from PIL import Image
14
+
15
+
16
+ def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
17
+ if sample in ["rand", "middle"]: # uniform sampling
18
+ acc_samples = min(num_frames, vlen)
19
+ # split the video into `acc_samples` intervals, and sample from each interval.
20
+ intervals = np.linspace(
21
+ start=0, stop=vlen, num=acc_samples + 1).astype(int)
22
+ ranges = []
23
+ for idx, interv in enumerate(intervals[:-1]):
24
+ ranges.append((interv, intervals[idx + 1] - 1))
25
+ if sample == 'rand':
26
+ try:
27
+ frame_indices = [random.choice(
28
+ range(x[0], x[1])) for x in ranges]
29
+ except Exception:
30
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
31
+ frame_indices.sort()
32
+ frame_indices = list(frame_indices)
33
+ elif fix_start is not None:
34
+ frame_indices = [x[0] + fix_start for x in ranges]
35
+ elif sample == 'middle':
36
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
37
+ else:
38
+ raise NotImplementedError
39
+
40
+ if len(frame_indices) < num_frames: # padded with last frame
41
+ padded_frame_indices = [frame_indices[-1]] * num_frames
42
+ padded_frame_indices[:len(frame_indices)] = frame_indices
43
+ frame_indices = padded_frame_indices
44
+ elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
45
+ output_fps = float(sample[3:])
46
+ duration = float(vlen) / input_fps
47
+ # gap between frames; this is also the clip length each frame represents
48
+ delta = 1 / output_fps
49
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
50
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
51
+ frame_indices = [e for e in frame_indices if e < vlen]
52
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
53
+ frame_indices = frame_indices[:max_num_frames]
54
+ else:
55
+ raise ValueError
56
+ return frame_indices
57
+
58
+
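# Standalone check of the 'middle' branch above: the clip is split into num_frames
# equal intervals and the centre frame of each is taken (values are illustrative).
import numpy as np

num_frames, vlen = 4, 100
intervals = np.linspace(start=0, stop=vlen, num=num_frames + 1).astype(int)
ranges = [(intervals[i], intervals[i + 1] - 1) for i in range(num_frames)]
middle_indices = [(lo + hi) // 2 for lo, hi in ranges]
# intervals      -> [0, 25, 50, 75, 100]
# middle_indices -> [12, 37, 62, 87]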
59
+ def get_index(num_frames, bound, fps, max_frame, first_idx=0):
60
+ if bound:
61
+ start, end = bound[0], bound[1]
62
+ else:
63
+ start, end = -100000, 100000
64
+ start_idx = max(first_idx, round(start * fps))
65
+ end_idx = min(round(end * fps), max_frame)
66
+ seg_size = float(end_idx - start_idx) / num_frames
67
+ frame_indices = np.array([
68
+ int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
69
+ for idx in range(num_frames)
70
+ ])
71
+ return frame_indices
72
+
73
+
74
+ def read_frames_gif(
75
+ video_path, num_frames, sample='rand', fix_start=None,
76
+ max_num_frames=-1, client=None, clip=None,
77
+ ):
78
+ if video_path.startswith('s3') or video_path.startswith('p2'):
79
+ video_bytes = client.get(video_path)
80
+ gif = imageio.get_reader(io.BytesIO(video_bytes))
81
+ else:
82
+ gif = imageio.get_reader(video_path)
83
+ vlen = len(gif)
84
+ frame_indices = get_frame_indices(
85
+ num_frames, vlen, sample=sample, fix_start=fix_start,
86
+ max_num_frames=max_num_frames
87
+ )
88
+ frames = []
89
+ reference_size = None
90
+ for index, frame in enumerate(gif):
91
+ # for index in frame_idxs:
92
+ if index in frame_indices:
93
+ if frame.ndim == 2:
94
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
95
+ elif frame.shape[2] == 4:
96
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
97
+ if reference_size is None:
98
+ reference_size = (frame.shape[1], frame.shape[0])
99
+ frame = cv2.resize(frame, reference_size,
100
+ interpolation=cv2.INTER_LINEAR)
101
+ frames.append(frame)
102
+ frames = np.stack(frames, axis=0) # .float() / 255
103
+
104
+ return frames
105
+
106
+
107
+ def read_frames_decord(
108
+ video_path, num_frames, sample='rand', fix_start=None,
109
+ max_num_frames=-1, client=None, clip=None
110
+ ):
111
+ if video_path.startswith('s3') or video_path.startswith('p2') or video_path.startswith('p_hdd') or video_path.startswith('cluster1'):
112
+ video_bytes = client.get(video_path)
113
+ video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
114
+ else:
115
+ video_reader = VideoReader(video_path, num_threads=1)
116
+ vlen = len(video_reader)
117
+ fps = video_reader.get_avg_fps()
118
+ duration = vlen / float(fps)
119
+
120
+ if clip:
121
+ vlen = int(duration * fps)
122
+ frame_indices = get_index(num_frames, clip, fps, vlen)
123
+ else:
124
+ frame_indices = get_frame_indices(
125
+ num_frames, vlen, sample=sample, fix_start=fix_start,
126
+ input_fps=fps, max_num_frames=max_num_frames
127
+ )
128
+ # if clip:
129
+ # frame_indices = [f + start_index for f in frame_indices]
130
+
131
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C)
132
+ return frames
133
+
134
+
135
+ def read_diff_frames_decord(
136
+ video_path, clip, client=None
137
+ ):
138
+ if video_path.startswith('s3') or video_path.startswith('p2') or video_path.startswith('p_hdd') or video_path.startswith('cluster1') or video_path.startswith('s_hdd'):
139
+ video_bytes = client.get(video_path)
140
+ video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
141
+ else:
142
+ video_reader = VideoReader(video_path, num_threads=1)
143
+ vlen = len(video_reader)
144
+ fps = video_reader.get_avg_fps()
145
+
146
+ start_idx = round(clip[0]*fps)
147
+ end_idx = min(round(clip[1]*fps), vlen)
148
+ frame_indices = [start_idx, end_idx]
149
+
150
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C)
151
+ return frames
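A typical call into the decord-backed reader above, using a local file so that no storage client is needed; the path and frame count are placeholders.

from llava.video_utils import read_frames_decord

frames = read_frames_decord("videos/demo.mp4", num_frames=8, sample="middle")
print(frames.shape)   # (8, H, W, 3) uint8 array, one row per sampled frame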
pyproject.toml ADDED
@@ -0,0 +1,37 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "llava"
7
+ version = "1.2.2.post1"
8
+ description = "Towards GPT-4 like large language and visual assistant."
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "torch==2.1.2", "torchvision==0.16.2",
17
+ "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
18
+ "accelerate==0.21.0", "peft", "bitsandbytes",
19
+ "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20
+ "gradio==4.16.0", "gradio_client==0.8.1", "openai", "spaces",
21
+ "requests", "httpx==0.24.0", "uvicorn", "fastapi", "decord",
22
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ train = ["deepspeed==0.12.6", "ninja", "wandb"]
27
+ build = ["build", "twine"]
28
+
29
+ [project.urls]
30
+ "Homepage" = "https://llava-vl.github.io"
31
+ "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
32
+
33
+ [tool.setuptools.packages.find]
34
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
35
+
36
+ [tool.wheel]
37
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]