PLTNUM / scripts /augmentation.py
sagawa's picture
Upload 17 files
4321e7e verified
raw
history blame
1.96 kB
import random
def random_change_augmentation(aas, cfg):
residue_tokens = ("A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y")
stracture_aware_tokens = ("a", "c", "d", "e", "f", "g", "h", "i", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "y")
length = len(aas)
swap_indices = random.sample(
range(length), int(length * cfg.random_change_ratio)
)
new_aas = ""
for i, aa in enumerate(aas):
if i in swap_indices:
if aas[i] in residue_tokens:
new_aas += random.choice(residue_tokens)
elif aas[i] in stracture_aware_tokens:
new_aas += random.choice(stracture_aware_tokens)
else:
new_aas += aa
return new_aas
def mask_augmentation(aas, cfg):
length = len(aas)
swap_indices = random.sample(
range(0, length // cfg.token_length),
int(length // cfg.token_length * cfg.mask_ratio),
)
for ith in swap_indices:
aas = (
aas[: ith * cfg.token_length]
+ "@" * cfg.token_length
+ aas[(ith + 1) * cfg.token_length :]
)
aas = aas.replace("@@", "<mask>").replace("@", "<mask>")
return aas
def random_delete_augmentation(aas, cfg):
length = len(aas)
swap_indices = random.sample(
range(0, length // cfg.token_length),
int(length // cfg.token_length * cfg.random_delete_ratio),
)
for ith in swap_indices:
aas = (
aas[: ith * cfg.token_length]
+ "@" * cfg.token_length
+ aas[(ith + 1) * cfg.token_length :]
)
aas = aas.replace("@@", "").replace("@", "")
return aas
def truncate_augmentation(aas, cfg):
length = len(aas)
if length > cfg.max_length:
diff = length - cfg.max_length
start = random.randint(0, diff)
return aas[start : start + cfg.max_length]
else:
return aas