import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from contextlib import suppress
import logging
from einops import rearrange
from peft import LoraConfig, get_peft_model
from bigmodelvis import Visualization

from .clip_encoder_hd import CLIPVisionTowerHD
from .conversation import get_conv_template
from .processors_conv import preprocess_qwen
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from transformers.generation import GenerationConfig
from transformers import Qwen2Config, Qwen2ForCausalLM

def get_autocast(precision, cache_enabled=True):
    if precision in ("amp_bfloat16", "amp_bf16", "bf16"):
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16, cache_enabled=cache_enabled)
    elif precision == "fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16, cache_enabled=cache_enabled)
    elif precision == "fp32":
        return suppress
    else:
        raise ValueError("Unsupported precision: {}".format(precision))

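# Usage sketch for `get_autocast` (illustrative): it returns a context-manager factory,
# so a mixed-precision forward pass looks like
#     autocast = get_autocast("bf16")
#     with autocast():
#         out = model(x)
# For "fp32" it returns `contextlib.suppress`, i.e. a no-op context manager.
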
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

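# Shape sketch for `MLP` (illustrative): MLP(input_dim=1024, hidden_dim=1024, output_dim=4096, num_layers=2)
# builds Linear(1024, 1024) -> ReLU -> Linear(1024, 4096); ReLU is applied after every
# layer except the last.
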
class InfMLLM_Unified_HD_Chat(PreTrainedModel):
    """Unified HD multimodal chat model: a CLIP HD vision tower whose features are
    projected by a two-layer MLP adapter into the embedding space of a Qwen2 causal LM."""

    def __init__(self, config, debug=False):
        super().__init__(config)

        self.lm_tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, use_fast=False, trust_remote_code=True)
        self.media_token_img = "<|image|>"
        self.media_token_id_img = self.lm_tokenizer(self.media_token_img, return_tensors="pt", add_special_tokens=False).input_ids.item()
        self.lm_model = Qwen2ForCausalLM(config.lm_config)

        self.lm_tokenizer.model_max_length = config.max_txt_len

        self.template_name = config.conv_style
        self.preprocess_function = preprocess_qwen

        # Learnable separator / newline embeddings passed to the HD vision encoder
        # (see forward_encoder_img).
        self.separate = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.newline = nn.Parameter(torch.zeros([1, 1, 1, 4096]))

        self.encoder_img = CLIPVisionTowerHD(config.vision_config, vision_select_layer=-2)
        self.encoder_img_ln = lambda x: x  # identity placeholder; no extra normalization applied here

        # Two-layer MLP adapter mapping vision features into the LM hidden size.
        self.adapter_img = nn.Sequential(
            nn.Linear(self.encoder_img.num_features * 4, self.lm_model.config.hidden_size),
            nn.GELU(),
            nn.Linear(self.lm_model.config.hidden_size, self.lm_model.config.hidden_size)
        )

        self.config = config
        self.precision = config.precision
        self._apply_lemmatizer = getattr(config, 'apply_lemmatizer', False)
        self._lemmatizer = None

    def forward_encoder_img(self, image):
        """Encode a list of images with the HD CLIP tower and project the features into
        the LM embedding space. Returns the projected embeddings and the number of valid
        feature tokens per image."""
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            assert isinstance(image, list)
            image_embeds, image_split = self.encoder_img(image, self.separate, self.newline)

            image_embeds = self.encoder_img_ln(image_embeds)
            image_embeds = self.adapter_img(image_embeds)
        return image_embeds, image_split

    def _concat_embeds(self,
                       prompt_embeds, prompt_ids, prompt_masks,
                       labels=None, padding='left'):
        """Pad a list of variable-length prompt embeddings (and their ids, masks and
        optional labels) to a common length and stack them into batch tensors.
        Ids are padded with the pad token, masks with 0 and labels with -100."""
        emb_lens = [len(emb) for emb in prompt_embeds]
        if len(set(emb_lens)) == 1:
            if labels is not None:
                return torch.stack(prompt_embeds, dim=0), torch.stack(prompt_ids, dim=0), torch.stack(prompt_masks, dim=0), torch.stack(labels, dim=0)
            return torch.stack(prompt_embeds, dim=0), torch.stack(prompt_ids, dim=0), torch.stack(prompt_masks, dim=0)

        pad_emb = self.lm_model.get_input_embeddings()(torch.tensor(self.lm_tokenizer.pad_token_id, device=prompt_embeds[0].device))

        prompt_embeds_new = pad_emb.expand(len(emb_lens), max(emb_lens), -1).clone()
        prompt_ids_new = torch.ones([len(emb_lens), max(emb_lens)]).to(prompt_ids[0]) * self.lm_tokenizer.pad_token_id
        prompt_masks_new = torch.zeros([len(emb_lens), max(emb_lens)]).to(prompt_masks[0])
        if labels is not None:
            labels_new = -100 * torch.ones([len(emb_lens), max(emb_lens)]).to(prompt_ids[0])

        for i, L in enumerate(emb_lens):
            if padding == 'left':
                prompt_embeds_new[i, -L:] = prompt_embeds[i]
                prompt_ids_new[i, -L:] = prompt_ids[i]
                prompt_masks_new[i, -L:] = prompt_masks[i]
                if labels is not None:
                    labels_new[i, -L:] = labels[i]
            elif padding == 'right':
                prompt_embeds_new[i, :L] = prompt_embeds[i]
                prompt_ids_new[i, :L] = prompt_ids[i]
                prompt_masks_new[i, :L] = prompt_masks[i]
                if labels is not None:
                    labels_new[i, :L] = labels[i]
            else:
                raise ValueError("padding must be 'left' or 'right', got {}".format(padding))

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

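    # Padding sketch for `_concat_embeds` (illustrative): with emb_lens = [3, 5] and
    # padding='left', row 0 becomes [PAD, PAD, e0, e1, e2] with attention mask
    # [0, 0, 1, 1, 1]; when all rows already share one length they are simply stacked.
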
    def _insert_media_feat(self,
                           prompt_embeds, prompt_ids, prompt_masks,
                           is_languages,
                           embeds_media, media_token_id,
                           index_list=None,
                           labels=None, len_media=None):
        """Splice media embeddings into each prompt right after its media token.
        The inserted positions get id/label -100 and an attention mask of 1
        (or 0 for pure-language samples). `index_list` restricts insertion to the given
        batch indices; `len_media` gives the number of valid media tokens per sample."""
        prompt_embeds_new = []
        prompt_masks_new = []
        prompt_ids_new = []
        labels_new = []
        device = embeds_media[0].device

        if index_list is not None:
            assert len(index_list) == len(embeds_media)
            assert len(embeds_media) <= len(prompt_embeds)

        for b in range(len(prompt_embeds)):
            if (index_list is not None) and (b not in index_list):
                # No media for this sample: pass everything through unchanged.
                prompt_embeds_new.append(prompt_embeds[b])
                prompt_ids_new.append(prompt_ids[b])
                prompt_masks_new.append(prompt_masks[b])
                if labels is not None:
                    labels_new.append(labels[b])
            else:
                _idx = prompt_ids[b].tolist().index(media_token_id)
                if index_list is not None:
                    b_media = index_list.index(b)
                else:
                    b_media = b

                if len_media is not None:
                    cur_embeds_media = embeds_media[b_media, :len_media[b_media]]
                else:
                    cur_embeds_media = embeds_media[b_media]

                prompt_embeds_new.append(torch.cat([prompt_embeds[b][:_idx + 1],
                                                    cur_embeds_media,
                                                    prompt_embeds[b][_idx + 1:]
                                                    ], dim=0))
                prompt_ids_new.append(torch.cat([prompt_ids[b][:_idx + 1],
                                                 torch.ones(len(cur_embeds_media), dtype=torch.long).to(device).fill_(-100),
                                                 prompt_ids[b][_idx + 1:]
                                                 ], dim=0))
                if labels is not None:
                    labels_new.append(torch.cat([labels[b][:_idx + 1],
                                                 torch.ones(len(cur_embeds_media), dtype=torch.long).to(device).fill_(-100),
                                                 labels[b][_idx + 1:]
                                                 ], dim=0))

                prompt_masks_new.append(torch.cat([prompt_masks[b][:_idx + 1],
                                                   torch.zeros(len(cur_embeds_media), dtype=torch.long).to(device) if is_languages[b] else
                                                   torch.ones(len(cur_embeds_media), dtype=torch.long).to(device),
                                                   prompt_masks[b][_idx + 1:]], dim=0))

        if labels is not None:
            return prompt_embeds_new, prompt_ids_new, prompt_masks_new, labels_new
        return prompt_embeds_new, prompt_ids_new, prompt_masks_new

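    # Insertion sketch for `_insert_media_feat` (illustrative): for ids
    # [..., <|image|>, "What", "is", ...] the image embeddings are spliced in right after
    # the <|image|> position; the corresponding id (and label) slots are filled with -100
    # so they never contribute to the LM loss.
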
    @torch.no_grad()
    def generate(
        self,
        samples,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        temperature=0.,
        return_prompts=False
    ):
        """Generate answers for a batch of conversations, optionally conditioned on images."""
        autocast = get_autocast(self.precision, cache_enabled=True)
        with autocast():
            conversations = samples['conversations']
            is_languages = [False] * len(conversations)

            image_img = samples.get('images', None)
            index_img = list(range(len(image_img))) if image_img is not None else []

            device = None
            embeds_img, len_img = None, None
            special_prefix = ["" for _ in range(len(conversations))]

            if (self.config.encoder_img is not None) and (image_img is not None) and len(index_img) > 0:
                # Prepend the image placeholder token to every conversation that has an image.
                for i in index_img:
                    special_prefix[i] = self.media_token_img + special_prefix[i]

                new_image_img = []
                for index in index_img:
                    new_image_img.append(image_img[index])
                embeds_img, len_img = self.forward_encoder_img(new_image_img)
                device = embeds_img.device

            conv = get_conv_template(self.template_name)
            roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

            prompts = []
            for i, source in enumerate(conversations):
                if roles[source[0]['from']] != conv.roles[0]:
                    # Drop a leading assistant turn so the conversation starts with the user.
                    source = source[1:]

                per_prefix = special_prefix[i]
                conv.messages = []
                for j, sentence in enumerate(source):
                    role = roles[sentence['from']]
                    assert role == conv.roles[j % 2], f'{i}'
                    sentence['value'] = sentence['value'].replace("<image>", "").strip()

                    if j == 0:
                        sentence['value'] = per_prefix + sentence['value']

                    conv.append_message(role, sentence['value'])
                prompts.append(conv.get_prompt())

            self.lm_tokenizer.padding_side = "left"
            if self.lm_tokenizer.bos_token is not None:
                prompt_text = [self.lm_tokenizer.bos_token + t for t in prompts]
            else:
                prompt_text = prompts

            prompt_tokens = self.lm_tokenizer(
                prompt_text,
                return_tensors="pt",
                padding="longest",
                truncation=False,
                add_special_tokens=False
            ).to(device)

            prompt_embeds = self.lm_model.get_input_embeddings()(prompt_tokens.input_ids)
            prompt_masks = prompt_tokens.attention_mask
            prompt_ids = prompt_tokens.input_ids
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding is on the left"

            if embeds_img is not None:
                prompt_embeds, prompt_ids, prompt_masks = self._insert_media_feat(prompt_embeds=prompt_embeds,
                                                                                  prompt_ids=prompt_ids,
                                                                                  prompt_masks=prompt_masks,
                                                                                  is_languages=is_languages,
                                                                                  embeds_media=embeds_img,
                                                                                  media_token_id=self.media_token_id_img,
                                                                                  index_list=index_img,
                                                                                  len_media=len_img)

            prompt_embeds, prompt_ids, prompt_masks = self._concat_embeds(prompt_embeds, prompt_ids, prompt_masks, padding="left")
            assert torch.all(prompt_ids[:, -1] != self.lm_tokenizer.pad_token_id), "make sure padding is on the left"

            kwargs = {}
            kwargs['max_new_tokens'] = max_length

            outputs = self.lm_model.generate(
                inputs_embeds=prompt_embeds,
                attention_mask=prompt_masks,
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                eos_token_id=self.lm_tokenizer.eos_token_id,
                min_length=min_length,
                **kwargs
            )
            output_text = self.lm_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
            output_text = [text.strip() for text in output_text]

        if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]):
            output_text = self._lemmatize(output_text)

        if return_prompts:
            return output_text, prompts
        return output_text

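    # Input sketch for `generate` (keys inferred from this file; how `images` are
    # preprocessed for CLIPVisionTowerHD is assumed, not shown here):
    #     samples = {
    #         "conversations": [[{"from": "human", "value": "Describe the image."}]],
    #         "images": [image_tensor],  # entry i pairs with conversations[i]
    #     }
    #     texts = model.generate(samples, num_beams=1, max_length=256)
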
    def _lemmatize(self, answers):
        """Lemmatize nouns and verbs in each answer with spaCy; other tokens are kept as-is."""
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy
                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError:
                logging.error(
                    """
                    Please install spacy and the en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """
                )
                exit(1)

        return self._lemmatizer
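

# Loading sketch (hypothetical path; the checkpoint's config must provide the fields used
# above, e.g. `vision_config`, `lm_config`, `conv_style`, `precision`, `max_txt_len`):
#     model = InfMLLM_Unified_HD_Chat.from_pretrained("path/to/checkpoint").eval().cuda()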