# InfMLLM2_7B_chat / processors_conv.py
import io
import sys
from typing import Dict

import torch
import torchvision.transforms as T
import transformers
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler

from .conversation import get_conv_template

IGNORE_TOKEN_ID = -100

def preprocess_qwen(
    template_name,
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    special_prefixs,
    text_only: bool = False,
    group_by_length: bool = False,
    ds_name: str = None
) -> Dict:
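    """Apply the conversation template to `sources` and tokenize the result.

    Each sample's first human turn gets its entry from `special_prefixs`
    prepended (e.g. image placeholder strings). Labels are a copy of
    `input_ids` with everything except the assistant replies masked to
    IGNORE_TOKEN_ID, so the loss is computed only on model outputs.
    """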
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
    assert len(sources) == len(special_prefixs)

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first message if it is not from human
            source = source[1:]

        per_prefix = special_prefixs[i]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            # llava-1.5 adds <image> at the start of the question; strip the
            # placeholder tags here and prepend the per-sample prefix instead.
            sentence['value'] = sentence['value'].replace("<image>", "").strip()
            sentence['value'] = sentence['value'].replace("<video>", "").strip()
            if j == 0:
                sentence['value'] = per_prefix + sentence['value']
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())
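
    # Prepend bos_token when the tokenizer defines one. (Qwen-style
    # tokenizers typically define no bos_token, so this is often a no-op.)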
    if tokenizer.bos_token is not None:
        new_conversations = []
        for conversation in conversations:
            conversation = tokenizer.bos_token + conversation
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    tokenizer.padding_side = 'right'  # right padding keeps content at offset 0 for the length arithmetic below
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length else 'longest',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + '\n' + conv.roles[1] + '\n'  # <|im_end|>\n<|im_start|>assistant\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        sep2 = conv.sep + '\n'  # <|im_end|>\n
        turns = conversation.split(sep2)
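        # Regroup the pieces produced by the split: the prompt opens with a
        # system block, so the first chunk merges system + first user + first
        # assistant; each later chunk pairs one user turn with one assistant
        # reply. Every chunk thus ends with exactly one assistant answer.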
        re_turns = [sep2.join(turns[:3]) + sep2]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(sep2.join(turns[conv_idx:conv_idx + 2]) + sep2)  # user + gpt

        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        # Mask <|endoftext|> positions; since it typically serves as Qwen's
        # pad token, this also masks the padding.
        endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
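            # parts[0] now ends with '<|im_end|>\n<|im_start|>assistant\n',
            # i.e. it spans everything in this chunk before the assistant reply.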
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[input_id {i}]', input_ids[:, cur_len: cur_len + instruction_len])
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                    f' conversation: {conversation}'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
        conversations=conversations
    )
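

# Minimal usage sketch (illustrative, not part of the original upload). The
# checkpoint name, the 'qwen-7b-chat' template name, and the pad-token guard
# are assumptions; substitute whatever this repo's training config actually
# uses. Run as a module (python -m <package>.processors_conv) so the relative
# import above resolves.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen-7B-Chat', trust_remote_code=True)
    if tokenizer.pad_token is None:  # preprocess_qwen reads pad_token_id
        tokenizer.pad_token = '<|endoftext|>'

    sources = [[
        {'from': 'human', 'value': '<image>\nDescribe the picture.'},
        {'from': 'gpt', 'value': 'A cat sitting on a windowsill.'},
    ]]
    batch = preprocess_qwen('qwen-7b-chat', sources, tokenizer,
                            special_prefixs=['<image>'])
    print(batch['conversations'][0])
    print(batch['input_ids'].shape, batch['labels'].shape)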