""" 화자 찾는 모델 유틸 파일들 """ class InputFeatures: """ BERT 모델의 입력들 """ def __init__(self, tokens, input_ids, input_mask, input_type_ids): self.tokens = tokens self.input_ids = input_ids self.input_mask = input_mask self.input_type_ids = input_type_ids def convert_examples_to_features(examples, tokenizer): """ 텍스트 segment를 단어 ID로 변환합니다. """ features = [] tokens_list = [] for (ex_index, example) in enumerate(examples): tokens = tokenizer.tokenize(example) tokens_list.append(tokens) new_tokens = [] input_type_ids = [] new_tokens.append("[CLS]") input_type_ids.append(0) new_tokens = new_tokens + tokens input_type_ids = input_type_ids + [0] * len(tokens) new_tokens.append("[SEP]") input_type_ids.append(0) input_ids = tokenizer.convert_tokens_to_ids(new_tokens) input_mask = [1] * len(input_ids) features.append( InputFeatures( tokens=new_tokens, input_ids=input_ids, input_mask=input_mask, input_type_ids=input_type_ids)) return features, tokens_list def get_alias2id(name_list_path) -> dict: """ 주어진 이름 목록 파일에서 별칭(alias)을 ID로 매핑하는 사전을 생성. """ with open(name_list_path, 'r', encoding='utf-8') as fin: name_lines = fin.readlines() alias2id = {} for i, line in enumerate(name_lines): for alias in line.strip().split()[1:]: alias2id[alias] = i return alias2id def find_speak(fs_model, input_data, tokenizer, alias2id): """ 주어진 모델과 입력 데이터를 사용하여 각 입력에 대한 화자를 찾는 함수 """ model = fs_model check_data_iter = iter(input_data) names = [] for _ in range(len(input_data)): seg_sents, css, scl, mp, qi, cut_css, name_list_index = next(check_data_iter) features, tokens_list = convert_examples_to_features(examples=css, tokenizer=tokenizer) try: predictions = model(features, scl, mp, qi, 0, "cuda:0", tokens_list, cut_css) except RuntimeError: predictions = model(features, scl, mp, qi, 0, "cpu", tokens_list, cut_css) scores, _, _ = predictions # 후처리 try: scores_np = scores.detach().cpu().numpy() scores_list = scores_np.tolist() score_index = scores_list.index(max(scores_list)) name_index = name_list_index[score_index] for key, val in alias2id.items(): if val == name_index: result_key = key names.append(result_key) except AttributeError: names.append('알 수 없음') return names def making_script(text, speaker:list, instance_num:list) -> str: """ 주어진 텍스트와 화자 목록, 해당하는 줄 번호를 사용하여 대화 스크립트를 생성하는 함수 """ lines = text.splitlines() for num, people in zip(instance_num, speaker): lines[num] = f'{people}: {lines[num]}' return lines