|
import io |
|
|
|
# Label id written into ``labels`` for token positions that must not be
# trained on (see the masking in this module).
# NOTE(review): presumably chosen to match the loss' ignore_index (PyTorch
# CrossEntropyLoss defaults to -100) — confirm against the training loop.
IGNORE_TOKEN_ID: int = -100
|
from typing import Dict |
|
|
|
import torch |
|
import torchvision.transforms as T |
|
import transformers |
|
from .conversation import get_conv_template |
|
from PIL import Image |
|
from torch.utils.data import ConcatDataset, WeightedRandomSampler |
|
import sys |
|
|
|
|
|
def preprocess_qwen(
    template_name,
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    special_prefixs,
    text_only: bool = False,
    group_by_length: bool = False,
    ds_name: str = None,
) -> Dict:
    """Render, tokenize and label a batch of chat samples.

    Each sample is rendered with the conversation template and tokenized.
    It also gets a ``labels`` row in which only the assistant replies keep
    their token ids. Every other position is set to ``IGNORE_TOKEN_ID``.

    Args:
        template_name: Name passed to ``get_conv_template``.
        sources: One conversation per sample. Each conversation is a list of
            dicts with a ``'from'`` key (``'human'`` or ``'gpt'``) and a
            ``'value'`` key (message text). A leading ``'gpt'`` message is
            dropped so the dialogue starts with the user. After that, messages
            must alternate human/gpt. All ``<image>`` / ``<video>`` placeholders
            are removed from the text. The dicts are only read, never modified.
        tokenizer: Tokenizer used for the batch and for per-segment lengths.
            Side effect: its ``padding_side`` is set to ``'right'``.
        special_prefixs: One string per sample, prepended to that sample's
            first (user) message.
        text_only: Unused here; kept for interface compatibility.
        group_by_length: If True the batch is tokenized with ``padding=False``;
            otherwise it is padded to the longest sample.
        ds_name: Dataset name, only used in the mismatch warning.

    Returns:
        A dict with four entries:

        - ``input_ids``: tokenizer output with ``return_tensors='pt'``.
        - ``labels``: a masked copy of ``input_ids``.
        - ``attention_mask``: ``input_ids != tokenizer.pad_token_id``.
        - ``conversations``: the rendered prompt strings, BOS-prefixed when
          the tokenizer defines a BOS token.

    Raises:
        AssertionError: If ``sources`` and ``special_prefixs`` differ in
            length, or if a conversation does not alternate human/gpt.

    NOTE(review): the label masking assumes a specific prompt layout, which
    should be confirmed for every template used:

    - a system segment comes first;
    - alternating user/assistant segments follow;
    - each segment ends with ``conv.sep`` plus a newline;
    - the assistant role is written as ``conv.roles[1]`` plus a newline.
    """
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    assert len(sources) == len(special_prefixs)

    # ---- Render every sample into a prompt string. ----
    conversations = []
    for sample_idx, source in enumerate(sources):
        # Drop a leading assistant message so the dialogue opens with the user.
        if roles[source[0]['from']] != conv.roles[0]:
            source = source[1:]

        prefix = special_prefixs[sample_idx]
        conv.messages = []
        for turn_idx, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[turn_idx % 2], f'{sample_idx}'
            # Build the text locally instead of writing back into
            # ``sentence``. Mutating the caller's dicts meant that
            # re-processing the same sample prepended the prefix again.
            value = sentence['value'].replace("<image>", "").strip()
            value = value.replace("<video>", "").strip()
            if turn_idx == 0:
                value = prefix + value
            conv.append_message(role, value)
        conversations.append(conv.get_prompt())

    if tokenizer.bos_token is not None:
        conversations = [tokenizer.bos_token + text for text in conversations]

    # ---- Tokenize the batch. ----
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length else 'longest',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # ---- Mask targets so the loss only sees assistant replies. ----
    turn_sep = conv.sep + '\n'                    # terminates every segment
    reply_sep = turn_sep + conv.roles[1] + '\n'   # user text -> assistant reply
    pad_id = int(tokenizer.pad_token_id)
    endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')

    # Each ``target`` is a row view of ``targets``, so in-place writes land there.
    for conversation, target in zip(conversations, targets):
        # Count real tokens BEFORE masking. If the pad id equals the
        # '<|endoftext|>' id, the masking below turns pads into
        # IGNORE_TOKEN_ID, and those would no longer match ``pad_id``.
        total_len = int(target.ne(pad_id).sum())
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        # Regroup the prompt into training segments:
        # - first segment: system + first user message + first reply;
        # - each later segment: one user message + its reply.
        # Every segment keeps its trailing ``turn_sep``.
        pieces = conversation.split(turn_sep)
        segments = [turn_sep.join(pieces[:3]) + turn_sep]
        for start in range(3, len(pieces), 2):
            segments.append(turn_sep.join(pieces[start:start + 2]) + turn_sep)

        cur_len = 0
        for segment in segments:
            parts = segment.split(reply_sep)
            if len(parts) != 2:
                # Not a user/reply pair (e.g. the trailing separator-only
                # segment); stop here and mask everything after cur_len.
                break
            segment_len = len(tokenizer(segment).input_ids)
            instruction_len = len(tokenizer(parts[0] + reply_sep).input_ids)
            # Hide the user instruction, including the assistant role header.
            target[cur_len:cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += segment_len

        target[cur_len:] = IGNORE_TOKEN_ID

        # If the sample was not truncated, the per-segment token counts must
        # add up to the real length. Otherwise the mask is misaligned, so the
        # whole sample is excluded from the loss.
        if cur_len < tokenizer.model_max_length and cur_len != total_len:
            target[:] = IGNORE_TOKEN_ID
            print(
                f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                f' #turn = {len(pieces) - 1}. (ignored). This dataset is {ds_name}.'
                f' conversation: {conversation}',
                flush=True,
            )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
        conversations=conversations,
    )
|
|
|
|
|
|