import numpy as np import torch from torch.nn import CrossEntropyLoss import itertools def is_hiragana_or_katakana(s): for char in s: if not ('\u3040' <= char <= '\u309F' or '\u30A0' <= char <= '\u30FF') or char == "ー": return False return True def add_dakuten_handakuten(query, string_type): def convert_to_hiragana(s): """与えられた文字列を平仮名に変換する""" result = [] for char in s: if 'ァ' <= char <= 'ヶ': # 片仮名を平仮名に変換 result.append(chr(ord(char) - 96)) else: result.append(char) return ''.join(result) def convert_to_katakana(s): """与えられた文字列を片仮名に変換する""" result = [] for char in s: if 'ぁ' <= char <= 'ゖ': # 平仮名を片仮名に変換 result.append(chr(ord(char) + 96)) else: result.append(char) return ''.join(result) if string_type == "hiragana": s = convert_to_hiragana(query) dakuon_map = { 'か': 'が', 'き': 'ぎ', 'く': 'ぐ', 'け': 'げ', 'こ': 'ご', 'さ': 'ざ', 'し': 'じ', 'す': 'ず', 'せ': 'ぜ', 'そ': 'ぞ', 'た': 'だ', 'ち': 'ぢ', 'つ': 'づ', 'て': 'で', 'と': 'ど', 'は': 'ば', 'ひ': 'び', 'ふ': 'ぶ', 'へ': 'べ', 'ほ': 'ぼ' } handakuon_map = { 'は': 'ぱ', 'ひ': 'ぴ', 'ふ': 'ぷ', 'へ': 'ぺ', 'ほ': 'ぽ' } elif string_type == "katakana": s = convert_to_katakana(query) dakuon_map = { 'カ': 'ガ', 'キ': 'ギ', 'ク': 'グ', 'ケ': 'ゲ', 'コ': 'ゴ', 'サ': 'ザ', 'シ': 'ジ', 'ス': 'ズ', 'セ': 'ゼ', 'ソ': 'ゾ', 'タ': 'ダ', 'チ': 'ヂ', 'ツ': 'ヅ', 'テ': 'デ', 'ト': 'ド', 'ハ': 'バ', 'ヒ': 'ビ', 'フ': 'ブ', 'ヘ': 'ベ', 'ホ': 'ボ', 'ウ': 'ヴ' } handakuon_map = { 'ハ': 'パ', 'ヒ': 'ピ', 'フ': 'プ', 'ヘ': 'ペ', 'ホ': 'ポ' } # 文字ごとに元の文字と濁音・半濁音をリストにする options = [] for char in s: temp = [char] if char in dakuon_map: temp.append(dakuon_map[char]) if char in handakuon_map: temp.append(handakuon_map[char]) options.append(temp) # 全ての組み合わせを生成 candidates = list(itertools.product(*options)) return candidates def add_dashes(s): if not s: return [''] # 再帰的に文字列の先頭以外の部分に「ー」を挿入するパターンを取得 substr_patterns = add_dashes(s[1:]) # 現在の文字を含めたパターンを生成 result = [] for pattern in substr_patterns: result.append(s[0] + pattern) # そのまま連結 result.append(s[0] + 'ー' + pattern) # 「ー」を挿入して連結 return result def compute_losses(candidates, model, tokenizer): inputs = tokenizer(candidates, return_tensors="pt", padding=True) inputs["labels"] = inputs["input_ids"].masked_fill(inputs["input_ids"] == tokenizer.pad_token_id, -100) inputs = inputs.to(model.device) with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits labels = inputs["labels"] shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss(reduction="none") losses_flat = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1)) losses_seq = losses_flat.view(shift_labels.shape) mask_labels = shift_labels != tokenizer.pad_token_id losses = torch.sum(losses_seq * mask_labels, -1) / mask_labels.sum(-1) return losses def search_candidates(query, query_candidates, model, tokenizer, top_k=100): old_query = query[:-1] if old_query not in query_candidates: old_candidates, _ = search_candidates(old_query, query_candidates, model=model, tokenizer=tokenizer, top_k=top_k) else: old_candidates, _ = query_candidates[old_query] string = query[-1] candidates = [] for string_type in ["hiragana", "katakana"]: candidates_ = add_dakuten_handakuten(string, string_type=string_type) for candidate_ in candidates_: candidates += add_dashes(candidate_) combinations = itertools.product(old_candidates, candidates) new_candidates = [''.join(pair) for pair in combinations] losses = compute_losses(new_candidates, model=model, tokenizer=tokenizer) sorted_items = torch.sort(losses) sorted_candidates = np.array(new_candidates)[sorted_items.indices.cpu().numpy()] topk_candidates = sorted_candidates[:top_k].tolist() topk_losses = sorted_items.values[:top_k].cpu().tolist() query_candidates[query] = (topk_candidates, topk_losses) return topk_candidates, topk_losses