from pathlib import Path
import types
from typing import Optional, Tuple, Union, List, Dict, Any
import gc
import openvino as ov
from openvino.runtime import opset13
import nncf
import numpy as np
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, VisionRotaryEmbedding
from transformers.cache_utils import DynamicCache
from transformers.modeling_outputs import ModelOutput, CausalLMOutputWithPast
from transformers.generation import GenerationConfig, GenerationMixin

model_ids = ["Qwen/Qwen2-VL-2B-Instruct", "Qwen/Qwen2-VL-7B-Instruct"]


def model_selector(default=model_ids[0]):
    import ipywidgets as widgets

    model_checkpoint = widgets.Dropdown(
        options=model_ids,
        value=default,
        description="Model:",
    )
    return model_checkpoint


def model_has_state(ov_model: ov.Model):
    return len(ov_model.get_sinks()) > 0


def model_has_input_output_name(ov_model: ov.Model, name: str):
    """
    Helper function for checking that model has specified input or output name

    Parameters:
      ov_model (ov.Model):
      name (str):
          name of input or output

    Returns:
      True if input or output with requested name exists else False
    """
    return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [])


def fuse_cache_reorder(
    ov_model: ov.Model,
    not_kv_inputs: List[str],
    key_value_input_names: List[str],
    gather_dim: int,
):
    """
    Fuses reorder_cache during generate cycle into ov.Model. Used with stateful models, because we can not modify model state directly.

    Adds a new beam_idx parameter and Gather op per each kv-cache input in a given model.
    Should be run before make_stateful. Implements optimum's _reorder_cache
    inside the model in the beginning of each iteration.
    Gather works along given gather_dim dimension that may vary from model to model.
    KV-cache inputs are identified based on names in key_value_input_names.
    Append the new beam_idx parameter to not_kv_inputs.

    Parameters:
      ov_model (`ov.Model`):
          openvino model for processing
      not_kv_inputs (`List[str]`):
          list of input nodes in model that are not related to past key values
      key_value_input_names (`List[str]`):
          list of names for key value input layers
      gather_dim (int):
          dimension for gathering cache during reorder pass
    """

    if model_has_input_output_name(ov_model, "beam_idx"):
        raise ValueError("Model already has fused cache")
    input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0]
    beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]))
    beam_idx.output(0).get_tensor().add_names({"beam_idx"})
    ov_model.add_parameters([beam_idx])
    not_kv_inputs.append(ov_model.inputs[-1])

    for input_name in key_value_input_names:
        parameter_output_port = ov_model.input(input_name)
        consumers = parameter_output_port.get_target_inputs()
        gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim))
        for consumer in consumers:
            consumer.replace_source_output(gather.output(0))
    ov_model.validate_nodes_and_infer_types()


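# A minimal illustration (not used by the pipeline above) of what the fused Gather
# computes at the start of each generation step: every kv-cache tensor is re-indexed
# along its batch dimension with `beam_idx`, i.e. `cache[beam_idx]`. This reference
# helper is only a plain-PyTorch sketch of that operation.
def _beam_reorder_reference(cache: torch.Tensor, beam_idx: torch.Tensor, gather_dim: int = 0) -> torch.Tensor:
    # torch.index_select along gather_dim mirrors opset13.gather(cache, beam_idx, gather_dim)
    return torch.index_select(cache, gather_dim, beam_idx.long())

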
def build_state_initializer(ov_model: ov.Model, batch_dim: int):
    """
    Build initialization ShapeOf Expression for all ReadValue ops

    Parameters:
      ov_model (ov.Model):
          openvino model
      batch_dim (int):
          index of dimension corresponding to batch size
    """
    input_ids = ov_model.input("inputs_embeds")
    batch = opset13.gather(
        opset13.shape_of(input_ids, output_type="i64"),
        opset13.constant([0]),
        opset13.constant(0),
    )
    for op in ov_model.get_ops():
        if op.get_type_name() == "ReadValue":
            dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))]
            dims[batch_dim] = batch
            dims = [(opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims]
            shape = opset13.concat(dims, axis=0)
            broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape)
            op.set_arguments([broadcast])
    ov_model.validate_nodes_and_infer_types()


def make_stateful(
    ov_model: ov.Model,
    not_kv_inputs: List[str],
    key_value_input_names: List[str],
    key_value_output_names: List[str],
    batch_dim: int,
    num_attention_heads: int,
    num_beams_and_batch: int = None,
):
    """
    Hides kv-cache inputs and outputs inside the model as variables.

    Parameters:
      ov_model (ov.Model):
          openvino model
      not_kv_inputs (`List[str]`):
          list of input nodes in model that are not related to past key values
      key_value_input_names (`List[str]`):
          list of names for key value input layers
      key_value_output_names (`List[str]`):
          list of names for key value output layers
      batch_dim (int):
          index of batch dimension in key value layers
      num_attention_heads (int):
          number of attention heads for batch dimension initialization
      num_beams_and_batch (int):
          precalculated number of beams and batch for shapes initialization
    """
    from openvino._offline_transformations import apply_make_stateful_transformation

    input_output_map = {}

    if num_beams_and_batch is not None:
        for input in not_kv_inputs:
            shape = input.get_partial_shape()
            if shape.rank.get_length() <= 2:
                shape[0] = num_beams_and_batch
                input.get_node().set_partial_shape(shape)
    for kv_name_pair in zip(key_value_input_names, key_value_output_names):
        input_output_map[kv_name_pair[0]] = kv_name_pair[1]
        if num_beams_and_batch is not None:
            input = ov_model.input(kv_name_pair[0])
            shape = input.get_partial_shape()
            shape[batch_dim] = num_beams_and_batch * num_attention_heads
            input.get_node().set_partial_shape(shape)

    if num_beams_and_batch is not None:
        ov_model.validate_nodes_and_infer_types()

    apply_make_stateful_transformation(ov_model, input_output_map)
    if num_beams_and_batch is None:
        build_state_initializer(ov_model, batch_dim)


def patch_stateful(ov_model):
    key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]]
    key_value_output_names = [key.get_any_name() for key in ov_model.outputs[1:]]
    not_kv_inputs = [input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())]
    if not key_value_input_names or not key_value_output_names:
        return
    batch_dim = 0
    num_attention_heads = 1

    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
    make_stateful(
        ov_model,
        not_kv_inputs,
        key_value_input_names,
        key_value_output_names,
        batch_dim,
        num_attention_heads,
        None,
    )


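# Illustrative sanity check (an assumption-laden sketch, not part of the conversion
# pipeline): after patch_stateful, a saved language model is expected to expose a
# "beam_idx" input, hide the per-layer past_key_values tensors inside ReadValue/Assign
# state, and therefore report True from model_has_state.
def _describe_patched_model(lang_model_path):
    patched = ov.Core().read_model(lang_model_path)
    return {
        "inputs": [inp.get_any_name() for inp in patched.inputs],
        "outputs": [out.get_any_name() for out in patched.outputs],
        "is_stateful": model_has_state(patched),
    }

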
core = ov.Core()


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()


LANGUAGE_MODEL_NAME = "openvino_language_model.xml"
IMAGE_EMBEDDING_NAME = "openvino_vision_embeddings_model.xml"
IMAGE_EMBEDDING_MERGER_NAME = "openvino_vision_embeddings_merger_model.xml"
TEXT_EMBEDDING_NAME = "openvino_text_embeddings_model.xml"


def convert_qwen2vl_model(model_id, output_dir, quantization_config):
    output_dir = Path(output_dir)

    lang_model_path = output_dir / LANGUAGE_MODEL_NAME
    image_embed_path = output_dir / IMAGE_EMBEDDING_NAME
    embed_token_path = output_dir / TEXT_EMBEDDING_NAME
    image_embed_merger_path = output_dir / IMAGE_EMBEDDING_MERGER_NAME

    if all(
        [
            lang_model_path.exists(),
            image_embed_path.exists(),
            image_embed_merger_path.exists(),
            embed_token_path.exists(),
        ]
    ):
        print(f"✅ {model_id} model already converted. You can find results in {output_dir}")
        return
    print(f"⌛ {model_id} conversion started. Be patient, it may take some time.")
    print("⌛ Load Original model")
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)
    model.config.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print("✅ Original model successfully loaded")

    if not embed_token_path.exists():
        print("⌛ Convert Input embedding model")
        ov_model = ov.convert_model(
            model.model.embed_tokens,
            example_input=torch.ones([2, 2], dtype=torch.int64),
        )
        ov.save_model(ov_model, embed_token_path)
        del ov_model
        cleanup_torchscript_cache()
        gc.collect()
        print("✅ Input embedding model successfully converted")

    if not image_embed_path.exists() or not image_embed_merger_path.exists():
        print("⌛ Convert Image embedding model")

        vision_embed_tokens = model.visual
        if not image_embed_path.exists():
            ov_model = ov.convert_model(vision_embed_tokens.patch_embed, example_input={"hidden_states": torch.randn([4988, 1176])})
            ov.save_model(ov_model, image_embed_path)
            del ov_model
            cleanup_torchscript_cache()

        def image_embed_forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor:
            for blk in self.blocks:
                hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)

            return self.merger(hidden_states)

        def sdpa_attn_forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None) -> torch.Tensor:
            from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision

            seq_length = hidden_states.shape[0]
            q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

            q = q.transpose(0, 1)
            k = k.transpose(0, 1)
            v = v.transpose(0, 1)
            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
            attn_output = attn_output.transpose(0, 1)
            attn_output = attn_output.reshape(seq_length, -1)
            attn_output = self.proj(attn_output)
            return attn_output

        def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor:
            hidden_states = hidden_states + self.attn(self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
            return hidden_states

        if not image_embed_merger_path.exists():
            vision_embed_tokens.forward = types.MethodType(image_embed_forward, vision_embed_tokens)
            for block in vision_embed_tokens.blocks:
                block.forward = types.MethodType(block_forward, block)
                block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)

            ov_model = ov.convert_model(
                vision_embed_tokens,
                example_input={
                    "hidden_states": torch.randn([4988, 1280]),
                    "attention_mask": torch.ones([1, 4988, 4988]),
                    "rotary_pos_emb": torch.randn([4988, 40]),
                },
            )
            if quantization_config is not None:
                print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
                ov_model = nncf.compress_weights(ov_model, **quantization_config)
                print("✅ Weights compression finished")

            ov.save_model(ov_model, image_embed_merger_path)
            del ov_model
            cleanup_torchscript_cache()
        del vision_embed_tokens
        gc.collect()
        print("✅ Image embedding model successfully converted")

    if not lang_model_path.exists():
        print("⌛ Convert Language model")

        def forward_wrap(
            self,
            attention_mask,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
        ):
            new_past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            result = self._orig_forward(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=new_past_key_values,
                inputs_embeds=inputs_embeds,
            )
            if past_key_values is not None:
                result["past_key_values"] = result["past_key_values"].to_legacy_cache()
            return tuple(result.values())

        model._orig_forward = model.forward
        model.forward = types.MethodType(forward_wrap, model)
        hidden_size = model.config.hidden_size
        num_pkv = model.config.num_hidden_layers
        pkv_shape = (2, model.config.num_key_value_heads, 2, hidden_size // model.config.num_attention_heads)
        cache_position = torch.arange(2, 4)
        position_ids = cache_position.view(1, 1, -1).expand(3, 2, -1)

        input_embeds = torch.randn((2, 2, hidden_size))
        attention_mask = torch.ones([2, 4], dtype=torch.long)
        input_names = ["attention_mask", "position_ids"]
        output_names = ["logits"]

        past_key_values = []
        for i in range(num_pkv):
            kv = [torch.randn(pkv_shape) for _ in range(2)]
            past_key_values.append(kv)
            input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"])
            output_names.extend([f"present.{i}.key", f"present.{i}.value"])
        input_names.append("inputs_embeds")

        example_input = {"inputs_embeds": input_embeds, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values}

        ov_model = ov.convert_model(
            model,
            example_input=example_input,
        )

        for input, input_name in zip(ov_model.inputs, input_names):
            input.get_tensor().set_names({input_name})

        for output, output_name in zip(ov_model.outputs, output_names):
            output.get_tensor().set_names({output_name})
        patch_stateful(ov_model)
        print("✅ Language model successfully converted")

        if quantization_config is not None:
            print(f"⌛ Weights compression with {quantization_config['mode']} mode started")
            ov_model = nncf.compress_weights(ov_model, **quantization_config)
            print("✅ Weights compression finished")

        ov.save_model(ov_model, lang_model_path, False)
        del ov_model
        cleanup_torchscript_cache()
        del model
        gc.collect()
    print(f"✅ {model_id} model conversion finished. You can find results in {output_dir}")


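# Example driver for the conversion above. This is a sketch under assumptions: the
# output directory name and the INT4 compression settings are illustrative choices,
# not values prescribed by this module; the dict is unpacked into nncf.compress_weights,
# so its keys must match that function's keyword arguments.
def _example_convert(model_id: str = model_ids[0], output_dir: str = "qwen2-vl-2b-instruct-int4"):
    compression_config = {
        "mode": nncf.CompressWeightsMode.INT4_ASYM,  # 4-bit asymmetric weight compression
        "group_size": 128,
        "ratio": 1.0,
    }
    convert_qwen2vl_model(model_id, output_dir, compression_config)
    return Path(output_dir)

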
class OVQwen2VLModel(GenerationMixin):
    def __init__(self, model_dir, device, ov_config=None):
        model_dir = Path(model_dir)
        self.model = core.read_model(model_dir / LANGUAGE_MODEL_NAME)
        self.image_embed = core.compile_model(model_dir / IMAGE_EMBEDDING_NAME, device, ov_config)
        self.image_embed_merger = core.compile_model(model_dir / IMAGE_EMBEDDING_MERGER_NAME, device, ov_config)
        self.embed_tokens = core.compile_model(model_dir / TEXT_EMBEDDING_NAME, device)
        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
        self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
        compiled_model = core.compile_model(self.model, device, ov_config)
        self.request = compiled_model.create_infer_request()
        self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        self.generation_config = GenerationConfig.from_model_config(self.config)
        self.main_input_name = "input_ids"
        self.device = torch.device("cpu")
        self.num_pkv = 2
        self._supports_cache_class = False
        self.next_beam_idx = None
        self._past_length = None
        self._rotary_pos_emb = VisionRotaryEmbedding(self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2)

    def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

    def __call__(self, *args, **kwargs) -> CausalLMOutputWithPast:
        return self.forward(
            *args,
            **kwargs,
        )

    def _reorder_cache(self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called.
        This is required to match `past_key_values` with the correct beam_idx at every generation step.
        """
        self.next_beam_idx = np.array(beam_idx)
        return past_key_values

    def _get_past_length(self, past_key_values=None):
        if past_key_values is None:
            return 0
        return self._past_length

    def get_rope_index(
        self,
        input_ids: torch.LongTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the 3D rope index based on image and video's temporal, height and width in LLM.

        Explanation:
            Each embedding sequence contains vision embedding and text embedding or just contains text embedding.

            For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
            Examples:
                input_ids: [T T T T T], here T is for text.
                temporal position_ids: [0, 1, 2, 3, 4]
                height position_ids: [0, 1, 2, 3, 4]
                width position_ids: [0, 1, 2, 3, 4]

            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
            and 1D rotary position embedding for text part.
            Examples:
                Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
                vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
                text temporal position_ids: [3, 4, 5, 6, 7]
                text height position_ids: [3, 4, 5, 6, 7]
                text width position_ids: [3, 4, 5, 6, 7]
                Here we calculate the text start position_ids as the max vision position_ids plus 1.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
                it.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if image_grid_thw is not None or video_grid_thw is not None:
            total_input_ids = input_ids
            position_ids = torch.ones(3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device)
            image_index, video_index = 0, 0
            for i, input_ids in enumerate(total_input_ids):
                if attention_mask is not None:
                    input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image
                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, 1, -1).expand(3, input_ids.shape[0], -1)
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )

        if getattr(outputs, "rope_deltas", None) is not None:
            model_kwargs["rope_deltas"] = outputs.rope_deltas

        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        rope_deltas = kwargs.get("rope_deltas", None)
        if attention_mask is not None and position_ids is None:
            if cache_position is None or (cache_position is not None and cache_position[0] == 0):
                position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
            else:
                batch_size, seq_length = input_ids.shape
                delta = cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        if cache_position[0] != 0:
            pixel_values = None
            pixel_values_videos = None

        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "pixel_values_videos": pixel_values_videos,
                "image_grid_thw": image_grid_thw,
                "video_grid_thw": video_grid_thw,
                "rope_deltas": rope_deltas,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

        >>> messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": "What is shown in this image?"},
                    ],
                },
            ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)[0]
            if pixel_values is not None:
                image_embeds = self.visual(pixel_values, image_grid_thw)
                image_mask = input_ids == self.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if pixel_values_videos is not None:
                video_embeds = self.visual(pixel_values_videos, video_grid_thw)
                video_mask = input_ids == self.config.video_token_id
                inputs_embeds[video_mask] = video_embeds
        if past_key_values is None:
            self.request.reset_state()
            self.next_beam_idx = np.arange(inputs_embeds.shape[0], dtype=int)
            self._past_length = 0
        inputs = {}
        inputs["inputs_embeds"] = inputs_embeds
        inputs["attention_mask"] = attention_mask
        inputs["position_ids"] = position_ids
        if "beam_idx" in self.input_names:
            inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int)
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = self.request.get_tensor("logits").data
        logits = torch.from_numpy(logits).to(self.device)
        past_key_values = ((),)
        self._past_length += inputs["inputs_embeds"].shape[1]

        return Qwen2VLCausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=past_key_values,
            rope_deltas=rope_deltas,
        )

    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
                w // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
                w // self.config.vision_config.spatial_merge_size,
                self.config.vision_config.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self._rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def visual(self, hidden_states, grid_thw):
        hidden_states = self.image_embed(hidden_states)[0]
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
        attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool)
        causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

        causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf"))

        res = self.image_embed_merger([hidden_states, causal_mask, rotary_pos_emb])[0]
        return res
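

# End-to-end usage sketch for the wrapper above. Everything here is illustrative:
# the model directory, device string, image URL and prompt are assumptions, and the
# model must have been converted with convert_qwen2vl_model first.
def _example_generate(model_dir: str = "qwen2-vl-2b-instruct-int4", device: str = "CPU"):
    import requests
    from PIL import Image

    ov_model = OVQwen2VLModel(model_dir, device)
    processor = AutoProcessor.from_pretrained(model_dir)

    image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt")

    generated_ids = ov_model.generate(**inputs, max_new_tokens=64)
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    new_tokens = generated_ids[:, inputs["input_ids"].shape[1] :]
    return processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]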