# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""
import os
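# Force the pure-Python protobuf implementation so the sentencepiece model
# protos can be parsed without descriptor-version errors; this must be set
# before any protobuf-backed module is imported.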
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from transformers import LlamaTokenizer
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
import sentencepiece as spm
import argparse


def is_chinese(uchar):
    """Return True if the unicode character is a Chinese character."""
    return '\u4e00' <= uchar <= '\u9fa5'


def is_chinese_string(string):
    """Return True if every character in the string is Chinese."""
    return all(is_chinese(c) for c in string)


def load_baichuan_vocab(vocab_file):
    """Load the Baichuan vocabulary; each non-empty line holds a token in its first column."""
    words = set()
    with open(vocab_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                words.add(line.strip().split()[0])
    return words


def load_jieba_vocab(jieba_vocab_file):
    """Load the jieba word-frequency file and sort its words by frequency, descending."""
    with open(jieba_vocab_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    word_freqs = [line.strip().split() for line in lines]
    word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
    return word_freqs


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
parser.add_argument('--jieba_word_size', default=20000, type=int)
args = parser.parse_args()
print(args)
    # Load the base LLaMA tokenizer and the domain SentencePiece model
    llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
    chinese_sp_model = spm.SentencePieceProcessor()
    chinese_sp_model.Load(args.domain_sp_model_file)
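    # Parse both models into editable ModelProto objects so their piece lists can be merged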
llama_spm = sp_pb2_model.ModelProto()
llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
chinese_spm = sp_pb2_model.ModelProto()
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())
# print number of tokens
print(len(llama_tokenizer), len(chinese_sp_model))
print(llama_tokenizer.all_special_tokens)
print(llama_tokenizer.all_special_ids)
print(llama_tokenizer.special_tokens_map)
# Add Chinese tokens to LLaMA tokenizer
    llama_spm_tokens_set = {p.piece for p in llama_spm.pieces}
    print(f"Before: {len(llama_spm_tokens_set)}")
added_set = set()
    for p in chinese_spm.pieces:
        piece = p.piece
        if piece not in llama_spm_tokens_set:
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0  # merged pieces get a neutral score
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
print(f"[add domain tokens]New model pieces: {len(llama_spm.pieces)}")
vocab = load_baichuan_vocab(args.baichuan_vocab_file)
print('baichuan vocab len:', len(vocab))
    baichuan_vocab_set = {i for i in vocab if is_chinese_string(i)}
print('baichuan chinese vocab size:', len(baichuan_vocab_set))
print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
    for piece in baichuan_vocab_set:
        if piece not in llama_spm_tokens_set and piece not in added_set:
            new_p = sp_pb2_model.ModelProto().SentencePiece()
            new_p.piece = piece
            new_p.score = 0
            llama_spm.pieces.append(new_p)
            added_set.add(piece)
print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")
if args.add_jieba:
word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
top_words = word_freqs[:args.jieba_word_size]
print('jieba top10 freq words:', top_words[:10])
        jieba_vocab_set = {i[0] for i in top_words if i}
print('jieba_vocab_set size:', len(jieba_vocab_set))
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
        for piece in jieba_vocab_set:
            if piece not in llama_spm_tokens_set and piece not in added_set:
                new_p = sp_pb2_model.ModelProto().SentencePiece()
                new_p.piece = piece
                new_p.score = 0
                llama_spm.pieces.append(new_p)
                added_set.add(piece)
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")
    # Save: write the merged SentencePiece model, then wrap it in a Hugging Face tokenizer
    output_sp_dir = 'merged_tokenizer_sp'
    output_hf_dir = 'merged_tokenizer_hf'  # the path to save the merged Chinese-LLaMA tokenizer
    os.makedirs(output_sp_dir, exist_ok=True)
    with open(os.path.join(output_sp_dir, 'chinese_llama.model'), 'wb') as f:
        f.write(llama_spm.SerializeToString())
    tokenizer = LlamaTokenizer(vocab_file=os.path.join(output_sp_dir, 'chinese_llama.model'))
    tokenizer.save_pretrained(output_hf_dir)
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")
    # Test: compare tokenization of the original and merged tokenizers
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
print(chinese_llama_tokenizer.all_special_tokens)
print(chinese_llama_tokenizer.all_special_ids)
print(chinese_llama_tokenizer.special_tokens_map)
print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
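    # Mixed English/Chinese sample text used to compare the two tokenizers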
text = '''this is a test, hello world. thisisatesthelloworld,
慕容复来到河边,姑苏慕容氏在外面丢了人。
1号店一周岁了,我们一古脑儿买了10斤零食。
巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
print("Test text:\n", text)
print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
print(f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")


if __name__ == '__main__':
main()