davidkim205's picture
Upload folder using huggingface_hub
577164e verified
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
import torch
from utils.simple_bleu import simple_score
import torch
# Prompt templates keyed by template name (usually the model path).
#
# Each entry provides:
#   'stop_words'    - strings handed to LocalStoppingCriteria by load_model().
#   'ko2en'/'en2ko' - prompt formats; '{0}' is replaced with the input text.
#   'trim_keywords' - markers at which generate() truncates the decoded output.
#
# The Korean instructions read "Translate the following sentence into
# English." (ko2en) and "... into Korean." (en2ko).
templates = {
    'gemma': {
        # An empty stop word would match every decoded string, so the Gemma
        # end-of-turn marker is used instead.
        'stop_words': ['<eos>', '<end_of_turn>'],
        'ko2en': '<bos><start_of_turn>user\n다음 문장을 영어로 번역하세요.{0}<end_of_turn>\n<start_of_turn>model:',
        'en2ko': '<bos><start_of_turn>user\n다음 문장을 한글로 번역하세요.{0}<end_of_turn>\n<start_of_turn>model:',
        'trim_keywords': ['<eos>', '<end_of_turn>'],
    },
    'openchat': {
        'stop_words': ['<eos>', '<|end_of_turn|>'],
        'ko2en': '<s> GPT4 Correct User: 다음 문장을 영어로 번역하세요. {0}<|end_of_turn|> GPT4 Correct Assistant:',
        'en2ko': '<s> GPT4 Correct User: 다음 문장을 한글로 번역하세요. {0}<|end_of_turn|> GPT4 Correct Assistant:',
        'trim_keywords': ['<eos>', '<|end_of_turn|>'],
    },
    'qwen': {
        'stop_words': ['<eos>', '<|im_end|>'],
        'ko2en': '<|im_start|>system \n You are a helpful assistant<|im_end|>\n <|im_start|>다음 문장을 영어로 번역하세요. \n {0}<|im_end|>\n<|im_start|>assistant\n',
        # This key was previously a second 'ko2en', which silently replaced
        # the entry above and left the template without an 'en2ko' prompt.
        'en2ko': '<|im_start|>system \n You are a helpful assistant<|im_end|>\n <|im_start|>다음 문장을 한글로 번역하세요. \n {0}<|im_end|>\n<|im_start|>assistant\n',
        'trim_keywords': ['<eos>', '<|im_end|>'],
    },
    'davidkim205/iris-7b': {
        'stop_words': ['</s>'],
        'ko2en': '[INST] 다음 문장을 영어로 번역하세요.{0} [/INST]',
        'en2ko': '[INST] 다음 문장을 한글로 번역하세요.{0} [/INST]',
        'trim_keywords': ['</s>'],
    },
    'squarelike/Gugugo-koen-7B-V1.1': {
        'stop_words': ['</s>', '</끝>'],
        'ko2en': '### 한국어: {0}</끝>\n### 영어:',
        'en2ko': "### 영어: {0}</끝>\n### 한국어:",
        'trim_keywords': ['</s>', '</끝>'],
    },
    'maywell/Synatra-7B-v0.3-Translation': {
        'stop_words': ['</s>', '</끝>', '<|im_end|>'],
        'ko2en': '<|im_start|>system\n주어진 문장을 영어로 번역해라.<|im_end|>\n<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant',
        'en2ko': '<|im_start|>system\n주어진 문장을 한국어로 번역해라.<|im_end|>\n<|im_start|>user\n{0}<|im_end|>\n<|im_start|>assistant',
        'trim_keywords': ['<|im_end|>'],
    },
    'Unbabel/TowerInstruct-7B-v0.1': {
        'stop_words': ['</s>', '</끝>', '<|im_end|>'],
        # The instruction sentence now names the same direction as the
        # "Korean:/English:" labels that follow it (they were swapped).
        'ko2en': '<|im_start|>user\nTranslate the following text from Korean into English.\nKorean: {0}\nEnglish:<|im_end|>\n<|im_start|>assistant',
        'en2ko': '<|im_start|>user\nTranslate the following text from English into Korean.\nEnglish: {0}\nKorean:<|im_end|>\n<|im_start|>assistant',
        'trim_keywords': ['<|im_end|>'],
    },
}

# Module-wide state for the currently loaded model. load_model() fills these
# slots and additionally stores the selected prompt set under 'template'.
model_info = {'model': None, 'tokenizer': None, 'stopping_criteria': None}
class LocalStoppingCriteria(StoppingCriteria):
    """Stop generation once the output ends with one of the stop words.

    The active check (``__call__``) decodes the first sequence of the batch
    and compares the resulting text against the stop words. A token-id based
    comparison is kept in ``_compare_token`` as an alternative.

    Args:
        tokenizer: Tokenizer used both to encode the stop words and to decode
            the generated ids.
        stop_words: Strings that terminate generation. ``None`` or an empty
            list disables stopping; empty strings are ignored.
    """

    def __init__(self, tokenizer, stop_words=None):
        super().__init__()
        # Drop empty entries: ``str.endswith('')`` is always true, so a blank
        # stop word would end generation on the very first step.
        stop_words = [word for word in (stop_words or []) if word]
        # reshape(-1) keeps every stop sequence 1-D, including single-token
        # ones (squeeze() would collapse those to 0-dim tensors).
        stops = [
            tokenizer(word, return_tensors='pt', add_special_tokens=False)['input_ids'].reshape(-1)
            for word in stop_words
        ]
        print('stop_words', stop_words)
        print('stop_words_ids', stops)
        self.stop_words = stop_words
        # Use the GPU when one exists, but stay usable on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.stops = [stop.to(device) for stop in stops]
        self.tokenizer = tokenizer

    def _compare_token(self, input_ids):
        """Return True if the first sequence ends with any stop-word ids."""
        generated = input_ids[0]
        for stop in self.stops:
            stop_len = stop.numel()
            # A stop sequence longer than the text so far cannot match.
            if stop_len == 0 or stop_len > generated.shape[-1]:
                continue
            if torch.equal(generated[-stop_len:], stop.to(generated.device)):
                return True
        return False

    def _compare_decode(self, input_ids):
        """Return True if the decoded first sequence ends with a stop word."""
        decoded = self.tokenizer.decode(input_ids[0])
        # endswith() accepts a tuple; an empty tuple never matches.
        return decoded.endswith(tuple(self.stop_words))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Decide whether generation should stop; ``scores`` is not used."""
        return self._compare_decode(input_ids)
def trim_sentence(sentence, keywords):
    """Cut ``sentence`` at the first occurrence of each keyword.

    Keywords are applied in order. Whenever one is found, everything from its
    first occurrence onward is dropped and the remainder is stripped of
    surrounding whitespace. A sentence containing none of the keywords is
    returned unchanged (and unstripped).

    Args:
        sentence: Text to trim.
        keywords: Iterable of marker strings; empty strings are ignored.

    Returns:
        The trimmed sentence.
    """
    for keyword in keywords:
        # An empty keyword is "found" at index 0 of any string and would
        # wipe out the whole sentence, so skip it.
        if not keyword:
            continue
        if keyword in sentence:
            # Keep only the text before the keyword, then strip whitespace.
            sentence = sentence.partition(keyword)[0].strip()
    return sentence
def load_model(path, template_name=None):
    """Load a causal LM with its tokenizer and register them in ``model_info``.

    Args:
        path: Model identifier or local path passed to ``from_pretrained``.
        template_name: Key into ``templates``. Defaults to ``path``; an
            unknown name falls back to the 'davidkim205/iris-7b' template.

    Side effects:
        Sets the 'model', 'tokenizer', 'template' and 'stopping_criteria'
        entries of the module-level ``model_info`` dict.
    """
    print('load_model', path)
    if template_name is None:
        template_name = path
    if template_name not in templates:
        print('unknown template', template_name, '- using davidkim205/iris-7b')
        template_name = 'davidkim205/iris-7b'
    template = templates[template_name]

    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(path)
    stopping_criteria = StoppingCriteriaList(
        [LocalStoppingCriteria(tokenizer=tokenizer, stop_words=template['stop_words'])]
    )

    # Publish everything in one step, only after every piece was built, so a
    # failure above never leaves model_info half-updated.
    model_info.update(
        model=model,
        tokenizer=tokenizer,
        template=template,
        stopping_criteria=stopping_criteria,
    )
def generate(prompt):
    """Run the loaded model on ``prompt`` and return the trimmed completion.

    Args:
        prompt: Fully formatted prompt string.

    Returns:
        The decoded text generated after the prompt, cut at the active
        template's trim keywords, or '' when no model has been loaded.
    """
    if model_info['model'] is None:
        print('model is null, load the model first.')
        return ''
    model = model_info['model']
    tokenizer = model_info['tokenizer']
    stopping_criteria = model_info['stopping_criteria']

    # Send the inputs to wherever the model lives instead of assuming "cuda";
    # the model is loaded with device_map='auto'.
    encoding = tokenizer(
        prompt,
        return_tensors='pt',
        return_token_type_ids=False,
    ).to(model.device)
    gen_tokens = model.generate(
        **encoding,
        max_new_tokens=2048,
        temperature=1.0,
        num_beams=5,
        stopping_criteria=stopping_criteria,
    )
    # Decode only the tokens produced after the prompt.
    prompt_end_size = encoding.input_ids.shape[1]
    result = tokenizer.decode(gen_tokens[0, prompt_end_size:])
    return trim_sentence(result, model_info['template']['trim_keywords'])
def translate_ko2en(text):
    """Translate Korean ``text`` to English with the loaded model.

    Returns '' (after printing a notice) when no model has been loaded,
    rather than failing on the missing 'template' entry of ``model_info``.
    """
    template = model_info.get('template')
    if template is None:
        print('model is null, load the model first.')
        return ''
    return generate(template['ko2en'].format(text))
def translate_en2ko(text):
    """Translate English ``text`` to Korean with the loaded model.

    Returns '' (after printing a notice) when no model has been loaded,
    rather than failing on the missing 'template' entry of ``model_info``.
    """
    template = model_info.get('template')
    if template is None:
        print('model is null, load the model first.')
        return ''
    return generate(template['en2ko'].format(text))
def main():
    """Interactive round-trip demo: Korean -> English -> Korean.

    Reads lines from stdin, translates each to English and back to Korean,
    and prints both translations plus ``simple_score`` of the original
    against the round-tripped text. Ends on EOF (Ctrl-D) or Ctrl-C.
    """
    # Other checkpoints with a matching entry in ``templates``:
    #   "squarelike/Gugugo-koen-7B-V1.1"
    #   "maywell/Synatra-7B-v0.3-Translation"
    #   "Unbabel/TowerInstruct-7B-v0.1"
    load_model("davidkim205/iris-7b")
    while True:
        try:
            text = input('>')
        except (EOFError, KeyboardInterrupt):
            # Leave the prompt loop cleanly instead of dumping a traceback.
            print()
            break
        if not text.strip():
            # Nothing to translate; prompt again.
            continue
        en_text = translate_ko2en(text)
        ko_text = translate_en2ko(en_text)
        print('------------------')
        print('en_text', en_text)
        print('ko_text', ko_text)
        print('score', simple_score(text, ko_text))


if __name__ == "__main__":
    main()