import gradio as gr
import sys
import random
import os
import pandas as pd
import torch
import itertools
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
sys.path.append("scripts/")
from foldseek_util import get_struc_seq
from utils import seed_everything
from models import PLTNUM_PreTrainedModel
from datasets_ import PLTNUMDataset
class Config:
    """Plain settings holder consumed by the prediction helpers in this file.

    Every instance starts with the same values; callers (and ``predict``)
    attach further attributes such as ``model_path`` or ``device`` later.
    """

    def __init__(self):
        # Populate the instance namespace in one shot; keyword order is the
        # attribute insertion order.
        vars(self).update(
            batch_size=2,  # DataLoader batch size
            use_amp=False,  # toggles torch autocast during inference
            num_workers=1,  # DataLoader worker processes
            max_length=512,  # bumped by one in predict() when used_sequence == "both"
            used_sequence="left",
            padding_side="right",  # forwarded to the tokenizer
            task="classification",  # "classification" => sigmoid on model output
            sequence_col="sequence",  # DataFrame column holding the input sequence
            seed=42,  # passed to seed_everything
        )
def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=None):
    """Predict stability for each uploaded PDB file and return a text report.

    Args:
        model_choice: Model family name; ``"SaProt"`` selects element 2 of the
            extracted sequence tuple, anything else selects element 0
            (presumably structure-aware vs. plain amino-acid sequence --
            confirm against ``get_struc_seq``).
        organism_choice: Forwarded to ``predict_stability_core``.
        pdb_files: Iterable of objects exposing a ``.name`` file path.
        cfg: Optional configuration object. A fresh ``Config`` is created per
            call when omitted, so state written during one request never
            leaks into the next.

    Returns:
        One line per file joined with newlines, or a single error message.
    """
    # Nothing uploaded: answer clearly instead of failing on iteration (None)
    # or returning an empty string (empty list).
    if not pdb_files:
        return "No PDB file provided."
    try:
        if cfg is None:
            cfg = Config()
        # Best-effort: make the bundled foldseek binary executable. Done once
        # per request rather than once per file.
        os.system("chmod 777 bin/foldseek")
        results = []
        for pdb_file in pdb_files:
            pdb_path = pdb_file.name
            sequences = get_foldseek_seq(pdb_path)
            if not sequences:
                results.append(f"Failed to extract sequence from {pdb_path}.")
                continue
            sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
            prediction = predict_stability_core(
                model_choice, organism_choice, sequence, cfg
            )
            results.append(f"Prediction for {pdb_path}: {prediction}")
        return "\n".join(results)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"An error occurred: {str(e)}"
def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=None):
    """Predict stability for a single sequence supplied as text.

    Args:
        model_choice: Model family name, forwarded to ``predict_stability_core``.
        organism_choice: Forwarded to ``predict_stability_core``.
        sequence: Input sequence string.
        cfg: Optional configuration object. A fresh ``Config`` is created per
            call when omitted, so state written during one request never
            leaks into the next.

    Returns:
        The prediction output, or a message describing why none was produced.
    """
    try:
        # Reject missing, empty and whitespace-only input before loading a model.
        if not sequence or not sequence.strip():
            return "No valid sequence provided."
        if cfg is None:
            cfg = Config()
        return predict_stability_core(model_choice, organism_choice, sequence, cfg)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"An error occurred: {str(e)}"
def predict_stability_core(model_choice, organism_choice, sequence, cfg=None):
    """Select the checkpoint for the given model/organism and run ``predict``.

    Args:
        model_choice: Model family name; becomes part of the checkpoint id and
            is stored as ``cfg.architecture``.
        organism_choice: ``"Human"`` selects the HeLa checkpoint, any other
            value selects NIH3T3.
        sequence: Input sequence passed to ``predict``.
        cfg: Optional configuration object; it is mutated in place. A fresh
            ``Config`` is created per call when omitted, instead of sharing
            one instance built at function-definition time.

    Returns:
        Whatever ``predict`` returns.
    """
    if cfg is None:
        cfg = Config()
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    checkpoint = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    # Both attributes carry the same checkpoint id.
    cfg.model = checkpoint
    cfg.architecture = model_choice
    cfg.model_path = checkpoint
    return predict(cfg, sequence)
def get_foldseek_seq(pdb_path):
    """Extract the chain "A" sequences from a PDB file via foldseek.

    Args:
        pdb_path: Path to the PDB file.

    Returns:
        The entry ``get_struc_seq`` produced for chain "A", or ``None`` when
        the result has no such chain, so callers can report a per-file
        failure instead of seeing a bare ``KeyError``.
    """
    parsed_seqs = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        ["A"],
        # Random id passed through to get_struc_seq as process_id.
        process_id=random.randint(0, 10000000),
    )
    # Only the lookup is guarded; errors raised inside get_struc_seq propagate.
    try:
        return parsed_seqs["A"]
    except KeyError:
        return None
def predict(cfg, sequence):
cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
if cfg.used_sequence == "both":
cfg.max_length += 1
seed_everything(cfg.seed)
df = pd.DataFrame({cfg.sequence_col: [sequence]})
tokenizer = AutoTokenizer.from_pretrained(
cfg.model_path, padding_side=cfg.padding_side
)
cfg.tokenizer = tokenizer
dataset = PLTNUMDataset(cfg, df, train=False)
dataloader = DataLoader(
dataset,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=True,
drop_last=False,
)
model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
model.to(cfg.device)
model.eval()
predictions = []
for inputs, _ in dataloader:
inputs = inputs.to(cfg.device)
with torch.no_grad():
with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
preds = (
torch.sigmoid(model(inputs))
if cfg.task == "classification"
else model(inputs)
)
predictions += preds.cpu().tolist()
predictions = list(itertools.chain.from_iterable(predictions))
outputs = {
"raw prediction values": predictions,
"binary prediction values": [1 if x > 0.5 else 0 for x in predictions]
}
html_output = f"""
Raw prediction value: {outputs['raw prediction values'][0]}
Binary prediction values: {outputs['binary prediction values'][0]}