import gradio as gr
import sys
import random
import os
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Must run before the project-local `scripts` imports below.
sys.path.append("/home/user/app/")

from scripts.foldseek_util import get_struc_seq
from scripts.utils import seed_everything
from scripts.models import PLTNUM_PreTrainedModel
from scripts.datasets import PLTNUMDataset


class Config:
    """Default inference settings for a single prediction request.

    ``predict_stability`` writes ``model``, ``architecture`` and ``model_path``
    onto the instance, and ``predict`` writes ``token_length``, ``device`` and
    ``tokenizer`` (and may bump ``max_length``). Use a fresh instance per
    request so those writes do not leak between calls.
    """

    batch_size = 2
    use_amp = False
    num_workers = 1
    max_length = 512
    used_sequence = "left"
    padding_side = "right"
    task = "classification"
    sequence_col = "sequence"
    # Read by predict() through seed_everything(cfg.seed); without it every
    # prediction failed with AttributeError.
    seed = 42


def predict_stability(cfg, model_choice, organism_choice, pdb_file=None, sequence=None):
    """Predict protein stability from an uploaded PDB file or a raw sequence.

    Args:
        cfg: A ``Config`` instance. It is mutated in place (model identifiers
            here, more fields inside ``predict``).
        model_choice: Base model name, "SaProt" or "ESM2"; it is interpolated
            into the Hugging Face repo id.
        organism_choice: "Human" selects the HeLa model; any other value
            selects the NIH3T3 model.
        pdb_file: Value from ``gr.File``. It may be a path string or an object
            exposing the path as ``.name``, depending on the Gradio version.
            When given, it takes precedence over ``sequence``.
        sequence: Protein sequence typed by the user. All whitespace is
            removed before prediction.

    Returns:
        A human-readable result string, or an error message string.
    """
    if pdb_file:
        # Newer Gradio passes a filepath str; older versions pass a tempfile-like
        # object whose .name is the path. Accept both.
        pdb_path = getattr(pdb_file, "name", pdb_file)
        try:
            # Best effort: the bundled binary may lack the executable bit.
            # If it is missing entirely, get_struc_seq will report the failure.
            os.chmod("bin/foldseek", 0o755)
        except OSError:
            pass
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."
        # NOTE(review): presumably index 0 is the plain amino-acid sequence and
        # index 2 the structure-aware combined sequence SaProt expects —
        # confirm against scripts.foldseek_util.get_struc_seq.
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    if isinstance(sequence, str):
        # Sequences pasted from FASTA files often contain line breaks/spaces.
        sequence = "".join(sequence.split())

    if not sequence:
        return "No valid input provided."

    # NOTE(review): for SaProt, predict() uses token_length=2. A plain
    # amino-acid sequence typed by the user carries no structure tokens —
    # confirm PLTNUMDataset handles that case.
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"
    model_id = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
    cfg.model = model_id
    cfg.architecture = model_choice
    cfg.model_path = model_id
    output = predict(cfg, sequence)
    return f"Predicted Stability using {model_choice} for {organism_choice}: Example Output with sequence {output}..."
def get_foldseek_seq(pdb_path):
    """Run foldseek on chain A of ``pdb_path`` and return its parsed entry.

    Returns:
        The ``get_struc_seq`` entry for chain "A", or ``None`` when chain A was
        not found. This lets callers report a clean error instead of crashing
        with KeyError.
    """
    parsed = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        ["A"],
        # NOTE(review): random id, presumably to keep temp files of concurrent
        # requests apart — confirm in scripts.foldseek_util.
        process_id=random.randint(0, 10000000),
    )
    # Assumes get_struc_seq returns a dict keyed by chain id.
    return parsed.get("A")


def predict(cfg, sequence):
    """Run the PLTNUM model named by ``cfg.model_path`` on one sequence.

    Args:
        cfg: Config-like object. Mutated in place: ``token_length``,
            ``device`` and ``tokenizer`` are set, and ``max_length`` is
            incremented when ``used_sequence == "both"``.
        sequence: Input sequence string placed in column ``cfg.sequence_col``.

    Returns:
        dict with "raw prediction values" (list of floats; sigmoid
        probabilities when ``cfg.task == "classification"``) and
        "binary prediction values" (the same values thresholded at 0.5).
    """
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    if cfg.used_sequence == "both":
        # Cumulative if the same cfg is reused; callers should pass a fresh one.
        cfg.max_length += 1
    seed_everything(getattr(cfg, "seed", 42))

    df = pd.DataFrame({cfg.sequence_col: [sequence]})
    cfg.tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        # Pinned memory only speeds host-to-GPU copies; on CPU it just warns.
        pin_memory=cfg.device == "cuda",
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    predictions = []
    for inputs, _ in dataloader:
        inputs = inputs.to(cfg.device)
        # torch.amp.autocast requires device_type; omitting it raises TypeError.
        with torch.no_grad(), torch.amp.autocast(
            device_type=cfg.device, enabled=cfg.use_amp
        ):
            logits = model(inputs)
            preds = torch.sigmoid(logits) if cfg.task == "classification" else logits
        # Flatten so each entry is a float, not a per-row list like [0.7];
        # assumes one output value per sequence — TODO confirm model head.
        predictions.extend(preds.reshape(-1).cpu().tolist())

    outputs = {}
    outputs["raw prediction values"] = predictions
    outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
    return outputs


def _predict_from_pdb(model_choice, organism_choice, pdb_file):
    """Gradio callback for the PDB tab.

    Builds a fresh Config per request, because predict_stability/predict
    mutate it.
    """
    return predict_stability(Config(), model_choice, organism_choice, pdb_file=pdb_file)


def _predict_from_sequence(model_choice, organism_choice, sequence):
    """Gradio callback for the sequence tab.

    Builds a fresh Config per request, because predict_stability/predict
    mutate it.
    """
    return predict_stability(Config(), model_choice, organism_choice, sequence=sequence)


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # PLTNUM: Protein LifeTime Neural Model
        **Predict the protein half-life from its sequence or PDB file.**
        """
    )

    gr.Image("https://github.com/sagawatatsuya/PLTNUM/blob/main/model-image.png?raw=true", label="Model Image")

    # Model and Organism selection in the same row to avoid layout issues
    with gr.Row():
        model_choice = gr.Radio(
            choices=["SaProt", "ESM2"],
            label="Select PLTNUM's base model.",
            value="SaProt",
        )
        organism_choice = gr.Radio(
            choices=["Mouse", "Human"],
            label="Select the target organism.",
            value="Mouse",
        )

    with gr.Tabs():
        with gr.TabItem("Upload PDB File"):
            gr.Markdown("### Upload your PDB file:")
            pdb_file = gr.File(label="Upload PDB File")

            pdb_predict_button = gr.Button("Predict Stability")
            pdb_prediction_output = gr.Textbox(label="Stability Prediction", interactive=False)

            # Gradio passes `inputs` positionally; the wrapper supplies cfg and
            # routes the file to the pdb_file keyword.
            pdb_predict_button.click(
                fn=_predict_from_pdb,
                inputs=[model_choice, organism_choice, pdb_file],
                outputs=pdb_prediction_output,
            )

        with gr.TabItem("Enter Protein Sequence"):
            gr.Markdown("### Enter the protein sequence:")
            sequence = gr.Textbox(
                label="Protein Sequence",
                placeholder="Enter your protein sequence here...",
                lines=8,
            )
            seq_predict_button = gr.Button("Predict Stability")
            seq_prediction_output = gr.Textbox(label="Stability Prediction", interactive=False)

            # Routes the text to the `sequence` keyword, not the pdb_file slot.
            seq_predict_button.click(
                fn=_predict_from_sequence,
                inputs=[model_choice, organism_choice, sequence],
                outputs=seq_prediction_output,
            )

    gr.Markdown(
        """
        ### How to Use:
        - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
        - **Select Organism**: Choose between 'Mouse' or 'Human'.
        - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
        - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
        - **Predict**: Click 'Predict Stability' to receive the prediction.
        """
    )

    gr.Markdown(
        """
        ### About the Tool
        This tool allows researchers and scientists to predict the stability of proteins using advanced algorithms. It supports both PDB file uploads and direct sequence input.
        """
    )

if __name__ == "__main__":
    demo.launch()