import random def random_change_augmentation(aas, cfg): residue_tokens = ("A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y") stracture_aware_tokens = ("a", "c", "d", "e", "f", "g", "h", "i", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "y") length = len(aas) swap_indices = random.sample( range(length), int(length * cfg.random_change_ratio) ) new_aas = "" for i, aa in enumerate(aas): if i in swap_indices: if aas[i] in residue_tokens: new_aas += random.choice(residue_tokens) elif aas[i] in stracture_aware_tokens: new_aas += random.choice(stracture_aware_tokens) else: new_aas += aa return new_aas def mask_augmentation(aas, cfg): length = len(aas) swap_indices = random.sample( range(0, length // cfg.token_length), int(length // cfg.token_length * cfg.mask_ratio), ) for ith in swap_indices: aas = ( aas[: ith * cfg.token_length] + "@" * cfg.token_length + aas[(ith + 1) * cfg.token_length :] ) aas = aas.replace("@@", "").replace("@", "") return aas def random_delete_augmentation(aas, cfg): length = len(aas) swap_indices = random.sample( range(0, length // cfg.token_length), int(length // cfg.token_length * cfg.random_delete_ratio), ) for ith in swap_indices: aas = ( aas[: ith * cfg.token_length] + "@" * cfg.token_length + aas[(ith + 1) * cfg.token_length :] ) aas = aas.replace("@@", "").replace("@", "") return aas def truncate_augmentation(aas, cfg): length = len(aas) if length > cfg.max_length: diff = length - cfg.max_length start = random.randint(0, diff) return aas[start : start + cfg.max_length] else: return aas