|
""" |
|
ํ์ ์ฐพ๋ ๋ชจ๋ธ ์ ํธ ํ์ผ๋ค |
|
""" |
|
class InputFeatures: |
|
""" |
|
BERT ๋ชจ๋ธ์ ์
๋ ฅ๋ค |
|
""" |
|
def __init__(self, tokens, input_ids, input_mask, input_type_ids): |
|
self.tokens = tokens |
|
self.input_ids = input_ids |
|
self.input_mask = input_mask |
|
self.input_type_ids = input_type_ids |
|
|
|
|
|
def convert_examples_to_features(examples, tokenizer): |
|
""" |
|
ํ
์คํธ segment๋ฅผ ๋จ์ด ID๋ก ๋ณํํฉ๋๋ค. |
|
""" |
|
features = [] |
|
tokens_list = [] |
|
|
|
for (ex_index, example) in enumerate(examples): |
|
tokens = tokenizer.tokenize(example) |
|
tokens_list.append(tokens) |
|
|
|
new_tokens = [] |
|
input_type_ids = [] |
|
|
|
new_tokens.append("[CLS]") |
|
input_type_ids.append(0) |
|
new_tokens = new_tokens + tokens |
|
input_type_ids = input_type_ids + [0] * len(tokens) |
|
new_tokens.append("[SEP]") |
|
input_type_ids.append(0) |
|
|
|
input_ids = tokenizer.convert_tokens_to_ids(new_tokens) |
|
input_mask = [1] * len(input_ids) |
|
|
|
features.append( |
|
InputFeatures( |
|
tokens=new_tokens, |
|
input_ids=input_ids, |
|
input_mask=input_mask, |
|
input_type_ids=input_type_ids)) |
|
|
|
return features, tokens_list |
|
|
|
|
|
def get_alias2id(name_list_path) -> dict: |
|
""" |
|
์ฃผ์ด์ง ์ด๋ฆ ๋ชฉ๋ก ํ์ผ์์ ๋ณ์นญ(alias)์ ID๋ก ๋งคํํ๋ ์ฌ์ ์ ์์ฑ. |
|
""" |
|
with open(name_list_path, 'r', encoding='utf-8') as fin: |
|
name_lines = fin.readlines() |
|
alias2id = {} |
|
|
|
for i, line in enumerate(name_lines): |
|
for alias in line.strip().split()[1:]: |
|
alias2id[alias] = i |
|
|
|
return alias2id |
|
|
|
|
|
def find_speak(fs_model, input_data, tokenizer, alias2id): |
|
""" |
|
์ฃผ์ด์ง ๋ชจ๋ธ๊ณผ ์
๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ์
๋ ฅ์ ๋ํ ํ์๋ฅผ ์ฐพ๋ ํจ์ |
|
""" |
|
model = fs_model |
|
check_data_iter = iter(input_data) |
|
|
|
names = [] |
|
|
|
for _ in range(len(input_data)): |
|
|
|
seg_sents, css, scl, mp, qi, cut_css, name_list_index = next(check_data_iter) |
|
features, tokens_list = convert_examples_to_features(examples=css, tokenizer=tokenizer) |
|
|
|
try: |
|
predictions = model(features, scl, mp, qi, 0, "cuda:0", tokens_list, cut_css) |
|
except RuntimeError: |
|
predictions = model(features, scl, mp, qi, 0, "cpu", tokens_list, cut_css) |
|
|
|
scores, _, _ = predictions |
|
|
|
|
|
try: |
|
scores_np = scores.detach().cpu().numpy() |
|
scores_list = scores_np.tolist() |
|
score_index = scores_list.index(max(scores_list)) |
|
name_index = name_list_index[score_index] |
|
|
|
for key, val in alias2id.items(): |
|
if val == name_index: |
|
result_key = key |
|
|
|
names.append(result_key) |
|
except AttributeError: |
|
names.append('์ ์ ์์') |
|
|
|
return names |
|
|
|
|
|
def making_script(text, speaker:list, instance_num:list) -> str: |
|
""" |
|
์ฃผ์ด์ง ํ
์คํธ์ ํ์ ๋ชฉ๋ก, ํด๋นํ๋ ์ค ๋ฒํธ๋ฅผ ์ฌ์ฉํ์ฌ ๋ํ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ๋ ํจ์ |
|
""" |
|
lines = text.splitlines() |
|
for num, people in zip(instance_num, speaker): |
|
lines[num] = f'{people}: {lines[num]}' |
|
return lines |
|
|