Upload additional required files
You can download gradio_helper.py and ov_qwen2_vl.py directly to avoid errors such as the GBK encoding error. qwen2vl.ipynb is the main file: it loads the model and builds the Gradio interface (a minimal usage sketch follows the file list below).
- gradio_helper.py +205 -0
- notebook_utils.py +715 -0
- ov_qwen2_vl.py +792 -0
- qwen2-build.py +43 -0
- qwen2vl.ipynb +0 -0
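
For orientation, the pieces fit together roughly as in the sketch below. This is a minimal sketch under stated assumptions, not the notebook verbatim: it assumes ov_qwen2_vl.py also provides an OVQwen2VLModel wrapper class (as in the upstream OpenVINO notebook; only its conversion helpers are visible in the diff below), and the model id, output directory, and device name are placeholders.

    from pathlib import Path

    from transformers import AutoProcessor

    # OVQwen2VLModel is assumed to exist in ov_qwen2_vl.py; convert_qwen2vl_model is shown below.
    from ov_qwen2_vl import convert_qwen2vl_model, OVQwen2VLModel
    from gradio_helper import make_demo

    model_id = "Qwen/Qwen2-VL-2B-Instruct"       # placeholder model id
    model_dir = Path("Qwen2-VL-2B-Instruct-ov")  # placeholder output directory

    # Convert the PyTorch checkpoint to OpenVINO IR (skipped if the IR files already exist).
    convert_qwen2vl_model(model_id, model_dir, quantization_config=None)

    processor = AutoProcessor.from_pretrained(model_dir)
    model = OVQwen2VLModel(model_dir, device="CPU")  # device name is a placeholder

    # Build and launch the Gradio chat interface defined in gradio_helper.py.
    demo = make_demo(model, processor)
    demo.launch()
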
gradio_helper.py
ADDED
@@ -0,0 +1,205 @@
import gradio as gr
import copy
import re
from threading import Thread
from transformers import TextIteratorStreamer
from qwen_vl_utils import process_vision_info


def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def _remove_image_special(text):
    text = text.replace("<ref>", "").replace("</ref>", "")
    return re.sub(r"<box>.*?(</box>|$)", "", text)


def is_video_file(filename):
    video_extensions = [".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg"]
    return any(filename.lower().endswith(ext) for ext in video_extensions)


def transform_messages(original_messages):
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message["content"]:
            if "image" in item:
                new_item = {"type": "image", "image": item["image"]}
            elif "text" in item:
                new_item = {"type": "text", "text": item["text"]}
            elif "video" in item:
                new_item = {"type": "video", "video": item["video"]}
            else:
                continue
            new_content.append(new_item)

        new_message = {"role": message["role"], "content": new_content}
        transformed_messages.append(new_message)

    return transformed_messages


def make_demo(model, processor):
    def call_local_model(model, processor, messages):
        messages = transform_messages(messages)

        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt").to(model.device)

        tokenizer = processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {"max_new_tokens": 512, "streamer": streamer, **inputs}

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

    def create_predict_fn():
        def predict(_chatbot, task_history):
            chat_query = _chatbot[-1][0]
            query = task_history[-1][0]
            if len(chat_query) == 0:
                _chatbot.pop()
                task_history.pop()
                return _chatbot
            print("User: " + _parse_text(query))
            history_cp = copy.deepcopy(task_history)
            full_response = ""
            messages = []
            content = []
            for q, a in history_cp:
                if isinstance(q, (tuple, list)):
                    if is_video_file(q[0]):
                        content.append({"video": f"file://{q[0]}"})
                    else:
                        content.append({"image": f"file://{q[0]}"})
                else:
                    content.append({"text": q})
                    messages.append({"role": "user", "content": content})
                    messages.append({"role": "assistant", "content": [{"text": a}]})
                    content = []
            messages.pop()

            for response in call_local_model(model, processor, messages):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))

                yield _chatbot
                full_response = _parse_text(response)

            task_history[-1] = (query, full_response)
            print("Qwen-VL-Chat: " + _parse_text(full_response))
            yield _chatbot

        return predict

    def create_regenerate_fn():
        def regenerate(_chatbot, task_history):
            if not task_history:
                return _chatbot
            item = task_history[-1]
            if item[1] is None:
                return _chatbot
            task_history[-1] = (item[0], None)
            chatbot_item = _chatbot.pop(-1)
            if chatbot_item[0] is None:
                _chatbot[-1] = (_chatbot[-1][0], None)
            else:
                _chatbot.append((chatbot_item[0], None))
            _chatbot_gen = predict(_chatbot, task_history)
            for _chatbot in _chatbot_gen:
                yield _chatbot

        return regenerate

    predict = create_predict_fn()
    regenerate = create_regenerate_fn()

    def add_text(history, task_history, text):
        task_text = text
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_file(history, task_history, file):
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        return gr.update(value="")

    def reset_state(task_history):
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>Qwen2-VL OpenVINO demo</center>""")

        chatbot = gr.Chatbot(label="Qwen2-VL", elem_classes="control-height", height=500)
        query = gr.Textbox(lines=2, label="Input")
        task_history = gr.State([])

        with gr.Row():
            addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image", "video"])
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")

        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

        gr.Markdown(
            """\
<font size=2>Note: This demo is governed by the original license of Qwen2-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen2-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)"""
        )

    return demo
notebook_utils.py
ADDED
@@ -0,0 +1,715 @@
#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import os
import platform
import sys
import threading
import time
import urllib.parse
from os import PathLike
from pathlib import Path
from typing import List, NamedTuple, Optional, Tuple

import numpy as np
from openvino.runtime import Core, Type, get_version
from IPython.display import HTML, Image, display

import openvino as ov
from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher
from openvino.runtime import opset10 as ops


# ## Files
#
# Load an image, download a file, download an IR model, and create a progress bar to show download progress.

# In[ ]:


def device_widget(default="AUTO", exclude=None, added=None):
    import openvino as ov
    import ipywidgets as widgets

    core = ov.Core()

    supported_devices = core.available_devices + ["AUTO"]
    exclude = exclude or []
    if exclude:
        for ex_device in exclude:
            if ex_device in supported_devices:
                supported_devices.remove(ex_device)

    added = added or []
    if added:
        for add_device in added:
            if add_device not in supported_devices:
                supported_devices.append(add_device)

    device = widgets.Dropdown(
        options=supported_devices,
        value=default,
        description="Device:",
        disabled=False,
    )
    return device


def quantization_widget(default=True):
    import ipywidgets as widgets

    to_quantize = widgets.Checkbox(
        value=default,
        description="Quantization",
        disabled=False,
    )

    return to_quantize


def pip_install(*args):
    import subprocess  # nosec - disable B404:import-subprocess check

    cli_args = []
    for arg in args:
        cli_args.extend(str(arg).split(" "))
    subprocess.run([sys.executable, "-m", "pip", "install", *cli_args], shell=(platform.system() == "Windows"), check=True)


def load_image(path: str) -> np.ndarray:
    """
    Loads an image from `path` and returns it as BGR numpy array. `path`
    should point to an image file, either a local filename or a url. The image is
    not stored to the filesystem. Use the `download_file` function to download and
    store an image.

    :param path: Local path name or URL to image.
    :return: image as BGR numpy array
    """
    import cv2
    import requests

    if path.startswith("http"):
        # Set User-Agent to Mozilla because some websites block
        # requests with User-Agent Python
        response = requests.get(path, headers={"User-Agent": "Mozilla/5.0"})
        array = np.asarray(bytearray(response.content), dtype="uint8")
        image = cv2.imdecode(array, -1)  # Loads the image as BGR
    else:
        image = cv2.imread(path)
    return image


def download_file(
    url: PathLike,
    filename: PathLike = None,
    directory: PathLike = None,
    show_progress: bool = True,
    silent: bool = False,
    timeout: int = 10,
) -> PathLike:
    """
    Download a file from a url and save it to the local filesystem. The file is saved to the
    current directory by default, or to `directory` if specified. If a filename is not given,
    the filename of the URL will be used.

    :param url: URL that points to the file to download
    :param filename: Name of the local file to save. Should point to the name of the file only,
        not the full path. If None the filename from the url will be used
    :param directory: Directory to save the file to. Will be created if it doesn't exist
        If None the file will be saved to the current working directory
    :param show_progress: If True, show a TQDM ProgressBar
    :param silent: If True, do not print a message if the file already exists
    :param timeout: Number of seconds before cancelling the connection attempt
    :return: path to downloaded file
    """
    from tqdm.notebook import tqdm_notebook
    import requests

    filename = filename or Path(urllib.parse.urlparse(url).path).name
    chunk_size = 16384  # make chunks bigger so that not too many updates are triggered for Jupyter front-end

    filename = Path(filename)
    if len(filename.parts) > 1:
        raise ValueError(
            "`filename` should refer to the name of the file, excluding the directory. "
            "Use the `directory` parameter to specify a target directory for the downloaded file."
        )

    # create the directory if it does not exist, and add the directory to the filename
    if directory is not None:
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
        filename = directory / Path(filename)

    try:
        response = requests.get(url=url, headers={"User-agent": "Mozilla/5.0"}, stream=True)
        response.raise_for_status()
    except (
        requests.exceptions.HTTPError
    ) as error:  # For error associated with not-200 codes. Will output something like: "404 Client Error: Not Found for url: {url}"
        raise Exception(error) from None
    except requests.exceptions.Timeout:
        raise Exception(
            "Connection timed out. If you access the internet through a proxy server, please "
            "make sure the proxy is set in the shell from where you launched Jupyter."
        ) from None
    except requests.exceptions.RequestException as error:
        raise Exception(f"File downloading failed with error: {error}") from None

    # download the file if it does not exist, or if it exists with an incorrect file size
    filesize = int(response.headers.get("Content-length", 0))
    if not filename.exists() or (os.stat(filename).st_size != filesize):
        with tqdm_notebook(
            total=filesize,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            desc=str(filename),
            disable=not show_progress,
        ) as progress_bar:
            with open(filename, "wb") as file_object:
                for chunk in response.iter_content(chunk_size):
                    file_object.write(chunk)
                    progress_bar.update(len(chunk))
                    progress_bar.refresh()
    else:
        if not silent:
            print(f"'{filename}' already exists.")

    response.close()

    return filename.resolve()


def download_ir_model(model_xml_url: str, destination_folder: PathLike = None) -> PathLike:
    """
    Download IR model from `model_xml_url`. Downloads model xml and bin file; the weights file is
    assumed to exist at the same location and name as model_xml_url with a ".bin" extension.

    :param model_xml_url: URL to model xml file to download
    :param destination_folder: Directory where downloaded model xml and bin are saved. If None, model
        files are saved to the current directory
    :return: path to downloaded xml model file
    """
    model_bin_url = model_xml_url[:-4] + ".bin"
    model_xml_path = download_file(model_xml_url, directory=destination_folder, show_progress=False)
    download_file(model_bin_url, directory=destination_folder)
    return model_xml_path


# ## Images

# ### Convert Pixel Data
#
# Normalize image pixel values between 0 and 1, and convert images to RGB and BGR.

# In[ ]:


def normalize_minmax(data):
    """
    Normalizes the values in `data` between 0 and 1
    """
    if data.max() == data.min():
        raise ValueError("Normalization is not possible because all elements of " f"`data` have the same value: {data.max()}.")
    return (data - data.min()) / (data.max() - data.min())


def to_rgb(image_data: np.ndarray) -> np.ndarray:
    """
    Convert image_data from BGR to RGB
    """
    import cv2

    return cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)


def to_bgr(image_data: np.ndarray) -> np.ndarray:
    """
    Convert image_data from RGB to BGR
    """
    import cv2

    return cv2.cvtColor(image_data, cv2.COLOR_RGB2BGR)


# ## Videos

# ### Video Player
#
# Custom video player to fulfill FPS requirements. You can set target FPS and output size, flip the video horizontally or skip first N frames.

# In[ ]:


class VideoPlayer:
    """
    Custom video player to fulfill FPS requirements. You can set target FPS and output size,
    flip the video horizontally or skip first N frames.

    :param source: Video source. It could be either camera device or video file.
    :param size: Output frame size.
    :param flip: Flip source horizontally.
    :param fps: Target FPS.
    :param skip_first_frames: Skip first N frames.
    """

    def __init__(self, source, size=None, flip=False, fps=None, skip_first_frames=0, width=1280, height=720):
        import cv2

        self.cv2 = cv2  # This is done to access the package in class methods
        self.__cap = cv2.VideoCapture(source)
        # try HD by default to get better video quality
        self.__cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.__cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

        if not self.__cap.isOpened():
            raise RuntimeError(f"Cannot open {'camera' if isinstance(source, int) else ''} {source}")
        # skip first N frames
        self.__cap.set(cv2.CAP_PROP_POS_FRAMES, skip_first_frames)
        # fps of input file
        self.__input_fps = self.__cap.get(cv2.CAP_PROP_FPS)
        if self.__input_fps <= 0:
            self.__input_fps = 60
        # target fps given by user
        self.__output_fps = fps if fps is not None else self.__input_fps
        self.__flip = flip
        self.__size = None
        self.__interpolation = None
        if size is not None:
            self.__size = size
            # AREA better for shrinking, LINEAR better for enlarging
            self.__interpolation = cv2.INTER_AREA if size[0] < self.__cap.get(cv2.CAP_PROP_FRAME_WIDTH) else cv2.INTER_LINEAR
        # first frame
        _, self.__frame = self.__cap.read()
        self.__lock = threading.Lock()
        self.__thread = None
        self.__stop = False

    """
    Start playing.
    """

    def start(self):
        self.__stop = False
        self.__thread = threading.Thread(target=self.__run, daemon=True)
        self.__thread.start()

    """
    Stop playing and release resources.
    """

    def stop(self):
        self.__stop = True
        if self.__thread is not None:
            self.__thread.join()
        self.__cap.release()

    def __run(self):
        prev_time = 0
        while not self.__stop:
            t1 = time.time()
            ret, frame = self.__cap.read()
            if not ret:
                break

            # fulfill target fps
            if 1 / self.__output_fps < time.time() - prev_time:
                prev_time = time.time()
                # replace by current frame
                with self.__lock:
                    self.__frame = frame

            t2 = time.time()
            # time to wait [s] to fulfill input fps
            wait_time = 1 / self.__input_fps - (t2 - t1)
            # wait until
            time.sleep(max(0, wait_time))

        self.__frame = None

    """
    Get current frame.
    """

    def next(self):
        import cv2

        with self.__lock:
            if self.__frame is None:
                return None
            # need to copy frame, because can be cached and reused if fps is low
            frame = self.__frame.copy()
        if self.__size is not None:
            frame = self.cv2.resize(frame, self.__size, interpolation=self.__interpolation)
        if self.__flip:
            frame = self.cv2.flip(frame, 1)
        return frame


# ## Visualization

# ### Segmentation
#
# Define a SegmentationMap NamedTuple that keeps the labels and colormap for a segmentation project/dataset. Create CityScapesSegmentation and BinarySegmentation SegmentationMaps. Create a function to convert a segmentation map to an RGB image with a colormap, and to show the segmentation result as an overlay over the original image.

# In[ ]:


class Label(NamedTuple):
    index: int
    color: Tuple
    name: Optional[str] = None


# In[ ]:


class SegmentationMap(NamedTuple):
    labels: List

    def get_colormap(self):
        return np.array([label.color for label in self.labels])

    def get_labels(self):
        labelnames = [label.name for label in self.labels]
        if any(labelnames):
            return labelnames
        else:
            return None


# In[ ]:


cityscape_labels = [
    Label(index=0, color=(128, 64, 128), name="road"),
    Label(index=1, color=(244, 35, 232), name="sidewalk"),
    Label(index=2, color=(70, 70, 70), name="building"),
    Label(index=3, color=(102, 102, 156), name="wall"),
    Label(index=4, color=(190, 153, 153), name="fence"),
    Label(index=5, color=(153, 153, 153), name="pole"),
    Label(index=6, color=(250, 170, 30), name="traffic light"),
    Label(index=7, color=(220, 220, 0), name="traffic sign"),
    Label(index=8, color=(107, 142, 35), name="vegetation"),
    Label(index=9, color=(152, 251, 152), name="terrain"),
    Label(index=10, color=(70, 130, 180), name="sky"),
    Label(index=11, color=(220, 20, 60), name="person"),
    Label(index=12, color=(255, 0, 0), name="rider"),
    Label(index=13, color=(0, 0, 142), name="car"),
    Label(index=14, color=(0, 0, 70), name="truck"),
    Label(index=15, color=(0, 60, 100), name="bus"),
    Label(index=16, color=(0, 80, 100), name="train"),
    Label(index=17, color=(0, 0, 230), name="motorcycle"),
    Label(index=18, color=(119, 11, 32), name="bicycle"),
    Label(index=19, color=(255, 255, 255), name="background"),
]

CityScapesSegmentation = SegmentationMap(cityscape_labels)

binary_labels = [
    Label(index=0, color=(255, 255, 255), name="background"),
    Label(index=1, color=(0, 0, 0), name="foreground"),
]

BinarySegmentation = SegmentationMap(binary_labels)


# In[ ]:


def segmentation_map_to_image(result: np.ndarray, colormap: np.ndarray, remove_holes: bool = False) -> np.ndarray:
    """
    Convert network result of floating point numbers to an RGB image with
    integer values from 0-255 by applying a colormap.

    :param result: A single network result after converting to pixel values in H,W or 1,H,W shape.
    :param colormap: A numpy array of shape (num_classes, 3) with an RGB value per class.
    :param remove_holes: If True, remove holes in the segmentation result.
    :return: An RGB image where each pixel is an int8 value according to colormap.
    """
    import cv2

    if len(result.shape) != 2 and result.shape[0] != 1:
        raise ValueError(f"Expected result with shape (H,W) or (1,H,W), got result with shape {result.shape}")

    if len(np.unique(result)) > colormap.shape[0]:
        raise ValueError(
            f"Expected max {colormap.shape[0]} classes in result, got {len(np.unique(result))} "
            "different output values. Please make sure to convert the network output to "
            "pixel values before calling this function."
        )
    elif result.shape[0] == 1:
        result = result.squeeze(0)

    result = result.astype(np.uint8)

    contour_mode = cv2.RETR_EXTERNAL if remove_holes else cv2.RETR_TREE
    mask = np.zeros((result.shape[0], result.shape[1], 3), dtype=np.uint8)
    for label_index, color in enumerate(colormap):
        label_index_map = result == label_index
        label_index_map = label_index_map.astype(np.uint8) * 255
        contours, hierarchies = cv2.findContours(label_index_map, contour_mode, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(
            mask,
            contours,
            contourIdx=-1,
            color=color.tolist(),
            thickness=cv2.FILLED,
        )

    return mask


def segmentation_map_to_overlay(image, result, alpha, colormap, remove_holes=False) -> np.ndarray:
    """
    Returns a new image where a segmentation mask (created with colormap) is overlayed on
    the source image.

    :param image: Source image.
    :param result: A single network result after converting to pixel values in H,W or 1,H,W shape.
    :param alpha: Alpha transparency value for the overlay image.
    :param colormap: A numpy array of shape (num_classes, 3) with an RGB value per class.
    :param remove_holes: If True, remove holes in the segmentation result.
    :return: An RGB image with segmentation mask overlayed on the source image.
    """
    import cv2

    if len(image.shape) == 2:
        image = np.repeat(np.expand_dims(image, -1), 3, 2)
    mask = segmentation_map_to_image(result, colormap, remove_holes)
    image_height, image_width = image.shape[:2]
    mask = cv2.resize(src=mask, dsize=(image_width, image_height))
    return cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)


# ### Network Results
#
# Show network result image, optionally together with the source image and a legend with labels.

# In[ ]:


def viz_result_image(
    result_image: np.ndarray,
    source_image: np.ndarray = None,
    source_title: str = None,
    result_title: str = None,
    labels: List[Label] = None,
    resize: bool = False,
    bgr_to_rgb: bool = False,
    hide_axes: bool = False,
):
    """
    Show result image, optionally together with source images, and a legend with labels.

    :param result_image: Numpy array of RGB result image.
    :param source_image: Numpy array of source image. If provided this image will be shown
        next to the result image. source_image is expected to be in RGB format.
        Set bgr_to_rgb to True if source_image is in BGR format.
    :param source_title: Title to display for the source image.
    :param result_title: Title to display for the result image.
    :param labels: List of labels. If provided, a legend will be shown with the given labels.
    :param resize: If true, resize the result image to the same shape as the source image.
    :param bgr_to_rgb: If true, convert the source image from BGR to RGB. Use this option if
        source_image is a BGR image.
    :param hide_axes: If true, do not show matplotlib axes.
    :return: Matplotlib figure with result image
    """
    import cv2
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D

    if bgr_to_rgb:
        source_image = to_rgb(source_image)
    if resize:
        result_image = cv2.resize(result_image, (source_image.shape[1], source_image.shape[0]))

    num_images = 1 if source_image is None else 2

    fig, ax = plt.subplots(1, num_images, figsize=(16, 8), squeeze=False)
    if source_image is not None:
        ax[0, 0].imshow(source_image)
        ax[0, 0].set_title(source_title)

    ax[0, num_images - 1].imshow(result_image)
    ax[0, num_images - 1].set_title(result_title)

    if hide_axes:
        for a in ax.ravel():
            a.axis("off")
    if labels:
        colors = labels.get_colormap()
        lines = [
            Line2D(
                [0],
                [0],
                color=[item / 255 for item in c.tolist()],
                linewidth=3,
                linestyle="-",
            )
            for c in colors
        ]
        plt.legend(
            lines,
            labels.get_labels(),
            bbox_to_anchor=(1, 1),
            loc="upper left",
            prop={"size": 12},
        )
    plt.close(fig)
    return fig


# ### Live Inference

# In[ ]:


def show_array(frame: np.ndarray, display_handle=None):
    """
    Display array `frame`. Replace information at `display_handle` with `frame`
    encoded as jpeg image. `frame` is expected to have data in BGR order.

    Create a display_handle with: `display_handle = display(display_id=True)`
    """
    import cv2

    _, frame = cv2.imencode(ext=".jpeg", img=frame)
    if display_handle is None:
        display_handle = display(Image(data=frame.tobytes()), display_id=True)
    else:
        display_handle.update(Image(data=frame.tobytes()))
    return display_handle


# ## Checks and Alerts
#
# Create an alert class to show stylized info/error/warning messages and a `check_device` function that checks whether a given device is available.

# In[ ]:


class NotebookAlert(Exception):
    def __init__(self, message: str, alert_class: str):
        """
        Show an alert box with the given message.

        :param message: The message to display.
        :param alert_class: The class for styling the message. Options: info, warning, success, danger.
        """
        self.message = message
        self.alert_class = alert_class
        self.show_message()

    def show_message(self):
        display(HTML(f"""<div class="alert alert-{self.alert_class}">{self.message}"""))


class DeviceNotFoundAlert(NotebookAlert):
    def __init__(self, device: str):
        """
        Show a warning message about an unavailable device. This class does not check whether or
        not the device is available, use the `check_device` function to check this. `check_device`
        also shows the warning if the device is not found.

        :param device: The unavailable device.
        :return: A formatted alert box with the message that `device` is not available, and a list
            of devices that are available.
        """
        ie = Core()
        supported_devices = ie.available_devices
        self.message = f"Running this cell requires a {device} device, " "which is not available on this system. "
        self.alert_class = "warning"
        if len(supported_devices) == 1:
            self.message += f"The following device is available: {ie.available_devices[0]}"
        else:
            self.message += "The following devices are available: " f"{', '.join(ie.available_devices)}"
        super().__init__(self.message, self.alert_class)


def check_device(device: str) -> bool:
    """
    Check if the specified device is available on the system.

    :param device: Device to check. e.g. CPU, GPU
    :return: True if the device is available, False if not. If the device is not available,
        a DeviceNotFoundAlert will be shown.
    """
    ie = Core()
    if device not in ie.available_devices:
        DeviceNotFoundAlert(device)
        return False
    else:
        return True


def check_openvino_version(version: str) -> bool:
    """
    Check if the specified OpenVINO version is installed.

    :param version: the OpenVINO version to check. Example: 2021.4
    :return: True if the version is installed, False if not. If the version is not installed,
        an alert message will be shown.
    """
    installed_version = get_version()
    if version not in installed_version:
        NotebookAlert(
            f"This notebook requires OpenVINO {version}. "
            f"The version on your system is: <i>{installed_version}</i>.<br>"
            "Please run <span style='font-family:monospace'>pip install --upgrade -r requirements.txt</span> "
            "in the openvino_env environment to install this version. "
            "See the <a href='https://github.com/openvinotoolkit/openvino_notebooks'>"
            "OpenVINO Notebooks README</a> for detailed instructions",
            alert_class="danger",
        )
        return False
    else:
        return True


packed_layername_tensor_dict_list = [{"name": "aten::mul/Multiply"}]


class ReplaceTensor(MatcherPass):
    def __init__(self, packed_layername_tensor_dict_list):
        MatcherPass.__init__(self)
        self.model_changed = False

        param = WrapType("opset10.Multiply")

        def callback(matcher: Matcher) -> bool:
            root = matcher.get_match_root()
            if root is None:
                return False
            for y in packed_layername_tensor_dict_list:
                root_name = root.get_friendly_name()
                if root_name.find(y["name"]) != -1:
                    max_fp16 = np.array([[[[-np.finfo(np.float16).max]]]]).astype(np.float32)
                    new_tenser = ops.constant(max_fp16, Type.f32, name="Constant_4431")
                    root.set_arguments([root.input_value(0).node, new_tenser])
                    packed_layername_tensor_dict_list.remove(y)

            return True

        self.register_matcher(Matcher(param, "ReplaceTensor"), callback)


def optimize_bge_embedding(model_path, output_model_path):
    """
    optimize_bge_embedding used to optimize BGE model for NPU device

    Arguments:
        model_path {str} -- original BGE IR model path
        output_model_path {str} -- Converted BGE IR model path
    """
    core = Core()
    ov_model = core.read_model(model_path)
    manager = Manager()
    manager.register_pass(ReplaceTensor(packed_layername_tensor_dict_list))
    manager.run_passes(ov_model)
    ov.save_model(ov_model, output_model_path, compress_to_fp16=False)
ov_qwen2_vl.py
ADDED
@@ -0,0 +1,792 @@
from pathlib import Path
import types
from typing import Optional, Tuple, Union, List, Dict, Any
import gc
import openvino as ov
from openvino.runtime import opset13
import nncf
import numpy as np
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, VisionRotaryEmbedding
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import ModelOutput
from transformers.generation import GenerationConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

model_ids = ["Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"]


def model_selector(default=model_ids[0]):
    import ipywidgets as widgets

    model_checkpoint = widgets.Dropdown(
        options=model_ids,
        default=default,
        description="Model:",
    )
    return model_checkpoint


def model_has_state(ov_model: ov.Model):
    return len(ov_model.get_sinks()) > 0


def model_has_input_output_name(ov_model: ov.Model, name: str):
    """
    Helper function for checking that model has specified input or output name

    Parameters:
        ov_model (ov.Model):
        name (str):
            name of input or output

    Returns:
        True if input or output with requested name exists else False
    """
    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])


def fuse_cache_reorder(
    ov_model: ov.Model,
    not_kv_inputs: List[str],
    key_value_input_names: List[str],
    gather_dim: int,
):
    """
    Fuses reorder_cache during the generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.

    Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
    Should be run before make_stateful. Implements optimum's _reorder_cache
    inside the model in the beginning of each iteration.
    Gather works along given gather_dim dimension that may vary from model to model.
    KV-cache inputs are identified based on names in key_value_input_names.
    Append the new beam_idx parameter to not_kv_inputs.

    Parameters:
        ov_model (`ov.Model`):
            openvino model for processing
        not_kv_inputs (`List[str]`):
            list of input nodes in model that are not related to past key values
        key_value_input_names (`List[str]`):
            list of names for key value input layers
        gather_dim (int):
            dimension for gathering cache during reorder pass
    """

    if model_has_input_output_name(ov_model, "beam_idx"):
        raise ValueError("Model already has fused cache")
    input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0]
    beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
    beam_idx.output(0).get_tensor().add_names({"beam_idx"})  # why list is not accepted?
    ov_model.add_parameters([beam_idx])
    not_kv_inputs.append(ov_model.inputs[-1])
    # Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx
    for input_name in key_value_input_names:
        parameter_output_port = ov_model.input(input_name)
        consumers = parameter_output_port.get_target_inputs()
        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
        for consumer in consumers:
            consumer.replace_source_output(gather.output(0))
    ov_model.validate_nodes_and_infer_types()


def build_state_initializer(ov_model: ov.Model, batch_dim: int):
    """
    Build initialization ShapeOf Expression for all ReadValue ops

    Parameters:
        ov_model (ov.Model):
            openvino model
        batch_dim (int):
            index of dimension corresponding to batch size
    """
    input_ids = ov_model.input("inputs_embeds")
    batch = opset13.gather(
        opset13.shape_of(input_ids, output_type="i64"),
        opset13.constant([0]),
        opset13.constant(0),
    )
    for op in ov_model.get_ops():
        if op.get_type_name() == "ReadValue":
            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
            dims[batch_dim] = batch
            dims = [(opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims]
            shape = opset13.concat(dims, axis=0)
            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
            op.set_arguments([broadcast])
    ov_model.validate_nodes_and_infer_types()


def make_stateful(
    ov_model: ov.Model,
    not_kv_inputs: List[str],
    key_value_input_names: List[str],
    key_value_output_names: List[str],
    batch_dim: int,
    num_attention_heads: int,
    num_beams_and_batch: int = None,
):
    """
    Hides kv-cache inputs and outputs inside the model as variables.

    Parameters:
        ov_model (ov.Model):
            openvino model
        not_kv_inputs (`List[str]`):
            list of input nodes in model that are not related to past key values
        key_value_input_names (`List[str]`):
            list of names for key value input layers
        key_value_output_names (`List[str]`):
            list of names for key value output layers
        batch_dim (int):
            index of batch dimension in key value layers
        num_attention_heads (int):
            number of attention heads for batch dimension initialization
        num_beams_and_batch (int):
            precalculated number of beams and batch for shapes initialization
    """
    from openvino._offline_transformations import apply_make_stateful_transformation

    input_output_map = {}

    if num_beams_and_batch is not None:
        # Set batch size for input_ids and attention mask to avoid a dynamic dimension being propagated from the end of the model back to ReadValue
        for input in not_kv_inputs:
            shape = input.get_partial_shape()
            if shape.rank.get_length() <= 2:  # == 1 for beam_index
                shape[0] = num_beams_and_batch
                input.get_node().set_partial_shape(shape)
    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
        if num_beams_and_batch is not None:
            input = ov_model.input(kv_name_pair[0])
            shape = input.get_partial_shape()
            shape[batch_dim] = num_beams_and_batch * num_attention_heads
            input.get_node().set_partial_shape(shape)

    if num_beams_and_batch is not None:
        # Re-validate model if shapes were altered above
        ov_model.validate_nodes_and_infer_types()

    apply_make_stateful_transformation(ov_model, input_output_map)
    if num_beams_and_batch is None:
        build_state_initializer(ov_model, batch_dim)


def patch_stateful(ov_model):
    key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]]
    key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]]
    not_kv_inputs = [input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())]
    if not key_value_input_names or not key_value_output_names:
        return
    batch_dim = 0
    num_attention_heads = 1

    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
    make_stateful(
        ov_model,
        not_kv_inputs,
        key_value_input_names,
        key_value_output_names,
        batch_dim,
        num_attention_heads,
        None,
    )


core = ov.Core()


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()


LANGUAGE_MODEL_NAME = "openvino_language_model.xml"
IMAGE_EMBEDDING_NAME = "openvino_vision_embeddings_model.xml"
IMAGE_EMBEDDING_MERGER_NAME = "openvino_vision_embeddings_merger_model.xml"
TEXT_EMBEDDING_NAME = "openvino_text_embeddings_model.xml"


def convert_qwen2vl_model(model_id, output_dir, quantization_config):
    output_dir = Path(output_dir)

    lang_model_path = output_dir / LANGUAGE_MODEL_NAME
    image_embed_path = output_dir / IMAGE_EMBEDDING_NAME
    embed_token_path = output_dir / TEXT_EMBEDDING_NAME
    image_embed_merger_path = output_dir / IMAGE_EMBEDDING_MERGER_NAME

    if all(
        [
            lang_model_path.exists(),
            image_embed_path.exists(),
            image_embed_merger_path.exists(),
            embed_token_path.exists(),
        ]
    ):
        print(f"✅ {model_id} model already converted. You can find results in {output_dir}")
        return
    print(f"⌛ {model_id} conversion started. Be patient, it may take some time.")
    print("⌛ Load Original model")
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    model.config.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print("✅ Original model successfully loaded")

    if not embed_token_path.exists():
        print("⌛ Convert Input embedding model")
        ov_model = ov.convert_model(
            model.model.embed_tokens,
            example_input=torch.ones([2, 2], dtype=torch.int64),
        )
        ov.save_model(ov_model, embed_token_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()
        print("✅ Input embedding model successfully converted")

    if not image_embed_path.exists() or not image_embed_merger_path.exists():
        print("⌛ Convert Image embedding model")

        vision_embed_tokens = model.visual
        if not image_embed_path.exists():
            ov_model = ov.convert_model(vision_embed_tokens.patch_embed, example_input={"hidden_states": torch.randn([4988, 1176])})
            ov.save_model(ov_model, image_embed_path)
            del ov_model
            cleanup_torchscript_cache()

        def image_embed_forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor:
            for blk in self.blocks:
                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)

            return self.merger(hidden_states)

        def sdpa_attn_forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None) -> torch.Tensor:
            from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision

            seq_length = hidden_states.shape[0]
            q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

            q = q.transpose(0, 1)
            k = k.transpose(0, 1)
            v = v.transpose(0, 1)
            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
            attn_output = attn_output.transpose(0, 1)
            attn_output = attn_output.reshape(seq_length, -1)
            attn_output = self.proj(attn_output)
            return attn_output

        def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor:
            hidden_states = hidden_states + self.attn(self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
            return hidden_states

        if not image_embed_merger_path.exists():
            vision_embed_tokens.forward = types.MethodType(image_embed_forward, vision_embed_tokens)
            for block in vision_embed_tokens.blocks:
                block.forward = types.MethodType(block_forward, block)
                block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)

            ov_model = ov.convert_model(
                vision_embed_tokens,
                example_input={
                    "hidden_states": torch.randn([4988, 1280]),
                    "attention_mask": torch.ones([1, 4988, 4988]),
                    "rotary_pos_emb": torch.randn([4988, 40]),
                },
            )
            if quantization_config is not None:
                print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
                ov_model = nncf.compress_weights(ov_model, **quantization_config)
                print("✅ Weights compression finished")

            ov.save_model(ov_model, image_embed_merger_path)
            del ov_model
            cleanup_torchscript_cache()
        del vision_embed_tokens
        gc.collect()
        print("✅ Image embedding model successfully converted")

    if not lang_model_path.exists():
        print("⌛ Convert Language model")

        def forward_wrap(
            self,
            attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
        ):
            new_past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            result = self._orig_forward(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=new_past_key_values,
                inputs_embeds=inputs_embeds,
            )
            if past_key_values is not None:
                result["past_key_values"] = result["past_key_values"].to_legacy_cache()
            return tuple(result.values())

        model._orig_forward = model.forward
        model.forward = types.MethodType(forward_wrap, model)
        hidden_size = model.config.hidden_size
        num_pkv = model.config.num_hidden_layers
        pkv_shape = (2, model.config.num_key_value_heads, 2, hidden_size // model.config.num_attention_heads)
        cache_position = torch.arange(2, 4)
        position_ids = cache_position.view(1, 1, -1).expand(3, 2, -1)

        input_embeds = torch.randn((2, 2, hidden_size))
        attention_mask = torch.ones([2, 4], dtype=torch.long)
        input_names = ["attention_mask", "position_ids"]
        output_names = ["logits"]

        past_key_values = []
        for i in range(num_pkv):
            kv = [torch.randn(pkv_shape) for _ in range(2)]
            past_key_values.append(kv)
            input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"])
            output_names.extend([f"present.{i}.key", f"present.{i}.value"])
        input_names.append("inputs_embeds")

        example_input = {"inputs_embeds": input_embeds, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values}

        ov_model = ov.convert_model(
            model,
            example_input=example_input,
        )

        for input, input_name in zip(ov_model.inputs, input_names):
            input.get_tensor().set_names({input_name})

        for output, output_name in zip(ov_model.outputs, output_names):
            output.get_tensor().set_names({output_name})
        patch_stateful(ov_model)
        print("✅ Language model successfully converted")

        if quantization_config is not None:
            print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
            ov_model = nncf.compress_weights(ov_model, **quantization_config)
            print("✅ Weights compression finished")

        ov.save_model(ov_model, lang_model_path, False)
|
382 |
+
del ov_model
|
383 |
+
cleanup_torchscript_cache()
|
384 |
+
del model
|
385 |
+
gc.collect()
|
386 |
+
print(f"✅ {model_id} model conversion finished. You can find results in {output_dir}")
|
387 |
+
|
388 |
+
|
389 |
+
class OVQwen2VLModel(GenerationMixin):
    def __init__(self, model_dir, device, ov_config=None):
        model_dir = Path(model_dir)
        self.model = core.read_model(model_dir / LANGUAGE_MODEL_NAME)
        self.image_embed = core.compile_model(model_dir / IMAGE_EMBEDDING_NAME, device, ov_config)
        self.image_embed_merger = core.compile_model(model_dir / IMAGE_EMBEDDING_MERGER_NAME, device, ov_config)
        self.embed_tokens = core.compile_model(model_dir / TEXT_EMBEDDING_NAME, device)
        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
        self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
        compiled_model = core.compile_model(self.model, device, ov_config)
        self.request = compiled_model.create_infer_request()
        self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        self.generation_config = GenerationConfig.from_model_config(self.config)
        self.main_input_name = "input_ids"
        self.device = torch.device("cpu")
        self.num_pkv = 2
        self._supports_cache_class = False
        self.next_beam_idx = None
        self._past_length = None
        self._rotary_pos_emb = VisionRotaryEmbedding(self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2)

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

    def __call__(self, *args, **kwargs) -> CausalLMOutputWithPast:
        return self.forward(
            *args,
            **kwargs,
        )

    def _reorder_cache(self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
        return past_key_values

    def _get_past_length(self, past_key_values=None):
        if past_key_values is None:
            return 0
        return self._past_length

    def get_rope_index(
        self,
        input_ids: torch.LongTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
            and 1D rotary position embedding for text part.
            Examples:
                Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
                vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
                text temporal position_ids: [3, 4, 5, 6, 7]
                text height position_ids: [3, 4, 5, 6, 7]
                text width position_ids: [3, 4, 5, 6, 7]
                Here we calculate the text start position_ids as the max vision position_ids plus 1.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if image_grid_thw is not None or video_grid_thw is not None:
            total_input_ids = input_ids
            position_ids = torch.ones(3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device)
            image_index, video_index = 0, 0
            for i, input_ids in enumerate(total_input_ids):
                if attention_mask is not None:
                    input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand(3, input_ids.shape[0], -1)
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )

        if getattr(outputs, "rope_deltas", None) is not None:
            model_kwargs["rope_deltas"] = outputs.rope_deltas

        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        rope_deltas = kwargs.get("rope_deltas", None)
        if attention_mask is not None and position_ids is None:
            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
                position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
            else:
                batch_size, seq_length = input_ids.shape
                delta = cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        if cache_position[0] != 0:
            pixel_values = None
            pixel_values_videos = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "pixel_values_videos": pixel_values_videos,
                "image_grid_thw": image_grid_thw,
                "video_grid_thw": video_grid_thw,
                "rope_deltas": rope_deltas,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)[0]
            if pixel_values is not None:
                pixel_values = pixel_values
                image_embeds = self.visual(pixel_values, image_grid_thw)
                image_mask = input_ids == self.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos
                video_embeds = self.visual(pixel_values_videos, video_grid_thw)
                video_mask = input_ids == self.config.video_token_id
                inputs_embeds[video_mask] = video_embeds
            if attention_mask is not None:
                attention_mask = attention_mask
        if past_key_values is None:
            self.request.reset_state()
            self.next_beam_idx = np.arange(inputs_embeds.shape[0], dtype=int)
            self._past_length = 0
        inputs = {}
        inputs["inputs_embeds"] = inputs_embeds
        inputs["attention_mask"] = attention_mask
        inputs["position_ids"] = position_ids
        if "beam_idx" in self.input_names:
            inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int)
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = self.request.get_tensor("logits").data
        logits = torch.from_numpy(logits).to(self.device)
        past_key_values = ((),)
        self._past_length += inputs["inputs_embeds"].shape[1]

        return Qwen2VLCausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=past_key_values,
            rope_deltas=rope_deltas,
        )

    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
                w // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
                w // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self._rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def visual(self, hidden_states, grid_thw):
        hidden_states = self.image_embed(hidden_states)[0]
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
        attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool)
        causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

        causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf"))

        res = self.image_embed_merger([hidden_states, causal_mask, rotary_pos_emb])[0]
        return res
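
For reference, below is a minimal sketch of how the converted model can be loaded and queried with the OVQwen2VLModel class defined above. It assumes conversion has already been written into model_dir (as done by qwen2-build.py below), that the qwen_vl_utils package is installed, and that "CPU" is the target device; the directory name, image path, and question are placeholders, not part of the uploaded files.

from pathlib import Path
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
from ov_qwen2_vl import OVQwen2VLModel

model_dir = Path("Qwen2-VL-2B-Instruct")   # placeholder: output directory of convert_qwen2vl_model
model = OVQwen2VLModel(model_dir, "CPU")   # reads and compiles the converted OpenVINO IR files
processor = AutoProcessor.from_pretrained(model_dir)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file://./demo.jpg"},  # placeholder image path
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

# GenerationMixin.generate drives prepare_inputs_for_generation/forward defined above
generated_ids = model.generate(**inputs, max_new_tokens=100)
trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
print(processor.batch_decode(trimmed, skip_special_tokens=True)[0])
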
qwen2-build.py
ADDED
@@ -0,0 +1,43 @@
from pathlib import Path
import requests
import os

# avoid "duplicate OpenMP runtime" crashes when PyTorch and OpenVINO both load libiomp
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# download the helper scripts from openvino_notebooks if they are not present locally
if not Path("ov_qwen2_vl.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/qwen2-vl/ov_qwen2_vl.py")
    open("ov_qwen2_vl.py", "w").write(r.text)

if not Path("notebook_utils.py").exists():
    r = requests.get(url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py")
    open("notebook_utils.py", "w").write(r.text)


# pick the Qwen2-VL checkpoint to convert
from ov_qwen2_vl import model_selector

model_id = model_selector()

print(model_id)

print(f"Selected {model_id.value}")
pt_model_id = model_id.value
model_dir = Path(pt_model_id.split("/")[-1])


from ov_qwen2_vl import convert_qwen2vl_model

# uncomment the next line (in a notebook) to see the model conversion code
# convert_qwen2vl_model??


import nncf

# INT4 weight-only compression settings passed to nncf.compress_weights during conversion
compression_configuration = {
    "mode": nncf.CompressWeightsMode.INT4_ASYM,
    "group_size": 32,
    "ratio": 1.0,
}

convert_qwen2vl_model(pt_model_id, model_dir, compression_configuration)
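
The INT4_ASYM configuration above favors the smallest footprint. As a hedged alternative (not part of the uploaded script), the same convert_qwen2vl_model call accepts other nncf compression modes, or None to skip weight compression entirely; because the function returns early when the output directory already contains converted IR files, re-converting with different settings needs a fresh directory. The "-int8"/"-fp" directory suffixes below are just illustrative names.

import nncf
from pathlib import Path
from ov_qwen2_vl import convert_qwen2vl_model

# INT8 weight-only compression: larger than INT4, usually closer to the original accuracy
int8_configuration = {"mode": nncf.CompressWeightsMode.INT8_ASYM}
convert_qwen2vl_model(pt_model_id, Path(str(model_dir) + "-int8"), int8_configuration)

# or keep the weights uncompressed
convert_qwen2vl_model(pt_model_id, Path(str(model_dir) + "-fp"), None)
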
qwen2vl.ipynb
ADDED
The diff for this file is too large to render.