""" | |
Input UI for RoseTTAfold All Atom | |
using two custom gradio components: gradio_molecule3d and gradio_cofoldinginput | |
""" | |
import gradio as gr
from gradio_cofoldinginput import CofoldingInput
from gradio_molecule3d import Molecule3D
import json
import yaml
from openbabel import openbabel
import zipfile
import tempfile
import os
from Bio.PDB import PDBParser, PDBIO
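
# Hydra base config that is written out as base.yaml next to the generated per-job
# config (which references it via `defaults: [base]`). The values below follow the
# RFAA inference defaults; see the RoseTTAFold-All-Atom repository for the
# authoritative settings.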
baseconfig = """job_name: "structure_prediction" | |
output_path: "" | |
checkpoint_path: RFAA_paper_weights.pt | |
database_params: | |
sequencedb: "" | |
hhdb: "pdb100_2021Mar03/pdb100_2021Mar03" | |
command: make_msa.sh | |
num_cpus: 4 | |
mem: 64 | |
protein_inputs: null | |
na_inputs: null | |
sm_inputs: null | |
covale_inputs: null | |
residue_replacement: null | |
chem_params: | |
use_phospate_frames_for_NA: True | |
use_cif_ordering_for_trp: True | |
loader_params: | |
n_templ: 4 | |
MAXLAT: 128 | |
MAXSEQ: 1024 | |
MAXCYCLE: 4 | |
BLACK_HOLE_INIT: False | |
seqid: 150.0 | |
legacy_model_param: | |
n_extra_block: 4 | |
n_main_block: 32 | |
n_ref_block: 4 | |
n_finetune_block: 0 | |
d_msa: 256 | |
d_msa_full: 64 | |
d_pair: 192 | |
d_templ: 64 | |
n_head_msa: 8 | |
n_head_pair: 6 | |
n_head_templ: 4 | |
d_hidden_templ: 64 | |
p_drop: 0.0 | |
use_chiral_l1: True | |
use_lj_l1: True | |
use_atom_frames: True | |
recycling_type: "all" | |
use_same_chain: True | |
lj_lin: 0.75 | |
SE3_param: | |
num_layers: 1 | |
num_channels: 32 | |
num_degrees: 2 | |
l0_in_features: 64 | |
l0_out_features: 64 | |
l1_in_features: 3 | |
l1_out_features: 2 | |
num_edge_features: 64 | |
n_heads: 4 | |
div: 4 | |
SE3_ref_param: | |
num_layers: 2 | |
num_channels: 32 | |
num_degrees: 2 | |
l0_in_features: 64 | |
l0_out_features: 64 | |
l1_in_features: 3 | |
l1_out_features: 2 | |
num_edge_features: 64 | |
n_heads: 4 | |
div: 4 | |
""" | |
def convert_format(input_file, jobname, chain, deleteIndexes, attachmentIndex):
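    """Convert a ChemDoodle JSON molecule to SDF with Open Babel, delete the atoms
    listed in deleteIndexes, and return the attachment index shifted accordingly."""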
    conv = openbabel.OBConversion()
    conv.SetInAndOutFormats('cdjson', 'sdf')
    # Add options
    conv.AddOption("c", openbabel.OBConversion.OUTOPTIONS, "1")

    with open(f"{jobname}_sm_{chain}.json", "w+") as fp:
        fp.write(input_file)

    mol = openbabel.OBMol()
    conv.ReadFile(mol, f"{jobname}_sm_{chain}.json")

    deleted_count = 0
    # delete atoms at the requested indexes (reverse order keeps indexes valid)
    for index in sorted(deleteIndexes, reverse=True):
        if index < attachmentIndex:
            deleted_count += 1
        atom = mol.GetAtom(index)
        mol.DeleteAtom(atom)
    # atoms removed before the attachment atom shift its index down
    attachmentIndex -= deleted_count

    conv.WriteFile(mol, f"{jobname}_sm_{chain}.sdf")
    return attachmentIndex
def prepare_input(input, jobname, baseconfig, hard_case):
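    """Translate the CofoldingInput dict into an RFAA Hydra job config, write the
    per-chain input files, and bundle everything into /tmp/<jobname>.zip.
    Returns the YAML config as a string and the path to the zip archive."""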
    input_categories = {"protein": "protein_inputs", "DNA": "na_inputs", "RNA": "na_inputs", "ligand": "sm_inputs"}
    # convert input to yaml format
    yaml_dict = {"defaults": ["base"], "job_name": jobname, "output_path": jobname}
    list_of_input_files = []

    if len(input["chains"]) == 0:
        raise gr.Error("At least one chain must be provided")

    for chain in input["chains"]:
        if input_categories[chain["class"]] not in yaml_dict.keys():
            yaml_dict[input_categories[chain["class"]]] = {}
        if input_categories[chain["class"]] in ["protein_inputs", "na_inputs"]:
            # write fasta
            with open(f"{jobname}_{chain['chain']}.fasta", "w+") as fp:
                fp.write(f">chain A\n{chain['sequence']}")
            if input_categories[chain["class"]] == "na_inputs":
                entry = {"input_type": chain["class"].lower(), "fasta": f"{jobname}/{jobname}_{chain['chain']}.fasta"}
            else:
                entry = {"fasta_file": f"{jobname}/{jobname}_{chain['chain']}.fasta"}
            list_of_input_files.append(f"{jobname}_{chain['chain']}.fasta")
            yaml_dict[input_categories[chain["class"]]][chain['chain']] = entry
        if input_categories[chain['class']] == "sm_inputs":
            if "smiles" in chain.keys():
                entry = {"input_type": "smiles", "input": chain["smiles"]}
            elif "sdf" in chain.keys():
                # write the uploaded SDF to a file
                with open(f"{jobname}_sm_{chain['chain']}.sdf", "w+") as fp:
                    fp.write(chain["sdf"])
                list_of_input_files.append(f"{jobname}_sm_{chain['chain']}.sdf")
                entry = {"input_type": "sdf", "input": f"{jobname}/{jobname}_sm_{chain['chain']}.sdf"}
            elif "name" in chain.keys():
                list_of_input_files.append(f"metal_sdf/{chain['name']}_ideal.sdf")
                entry = {"input_type": "sdf", "input": f"{jobname}/{chain['name']}_ideal.sdf"}
            yaml_dict["sm_inputs"][chain['chain']] = entry

    covale_inputs = []
    if len(input["covMods"]) > 0:
        yaml_dict["covale_inputs"] = ""
        for covMod in input["covMods"]:
            new_attachment_index = covMod["attachmentIndex"]
            if len(covMod["deleteIndexes"]) > 0:
                new_attachment_index = convert_format(covMod["mol"], jobname, covMod["ligand"], covMod["deleteIndexes"], covMod["attachmentIndex"])
            chirality_ligand = "null"
            chirality_protein = "null"
            if covMod["protein_symmetry"] in ["CW", "CCW"]:
                chirality_protein = covMod["protein_symmetry"]
            if covMod["ligand_symmetry"] in ["CW", "CCW"]:
                chirality_ligand = covMod["ligand_symmetry"]
            covale_inputs.append(((covMod["protein"], covMod["residue"], covMod["atom"]), (covMod["ligand"], new_attachment_index), (chirality_protein, chirality_ligand)))
    if len(input["covMods"]) > 0:
        yaml_dict["covale_inputs"] = json.dumps(json.dumps(covale_inputs))[1:-1].replace("'", "\"")

    if hard_case:
        yaml_dict["loader_params"] = {}
        yaml_dict["loader_params"]["MAXCYCLE"] = 10

    # write yaml to tmp
    with open(f"/tmp/{jobname}.yaml", "w+") as fp:
        # need to convert single quotes to double quotes
        fp.write(yaml.dump(yaml_dict).replace("'", "\""))
    # write baseconfig
    with open("/tmp/base.yaml", "w+") as fp:
        fp.write(baseconfig)
    list_of_input_files.append(f"/tmp/{jobname}.yaml")
    list_of_input_files.append("/tmp/base.yaml")

    # bundle all input files into a zip archive
    with zipfile.ZipFile(os.path.join("/tmp/", f"{jobname}.zip"), 'w') as zip_archive:
        for file in set(list_of_input_files):
            zip_archive.write(file, arcname=os.path.join(jobname, os.path.basename(file)), compress_type=zipfile.ZIP_DEFLATED)

    return yaml.dump(yaml_dict).replace("'", "\""), os.path.join("/tmp/", f"{jobname}.zip")
def convert_bfactors(pdb_path):
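    """Rescale per-atom B-factors (pLDDT written as 0-1 by RFAA) to a 0-100 range
    and write the result to <name>_processed.pdb."""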
    with open(pdb_path, 'r') as f:
        lines = f.readlines()
    for i, line in enumerate(lines):
        # multiply each B-factor by 100
        if line[0:6] == 'ATOM  ' or line[0:6] == 'HETATM':
            bfactor = float(line[60:66])
            bfactor *= 100
            line = line[:60] + f'{bfactor:6.2f}' + line[66:]
            lines[i] = line
    with open(pdb_path.replace(".pdb", "_processed.pdb"), 'w') as f:
        f.write(''.join(lines))
def run_rf2aa(jobname, zip_archive):
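    """Unpack the prepared job archive, run RFAA inference via rf2aa.run_inference,
    and return the path to the processed (rescaled) PDB file."""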
    current_dir = os.getcwd()
    try:
        with zipfile.ZipFile(zip_archive, 'r') as zip_ref:
            zip_ref.extractall(current_dir)
        os.system(f"python -m rf2aa.run_inference --config-name {jobname}.yaml --config-path {current_dir}/{jobname}")
        # scale pLDDT to 0-100 range in pdb output file
        convert_bfactors(f"{current_dir}/{jobname}/{jobname}.pdb")
    except Exception as e:
        raise gr.Error(f"Error running RFAA: {e}")
    return f"{current_dir}/{jobname}/{jobname}_processed.pdb"
def predict(input, jobname, dry_run, baseconfig, hard_case):
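    """Gradio callback: build the RFAA input files, then either show the generated
    files (dry run) or run the prediction and display the structure in Molecule3D."""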
    yaml_input, zip_archive = prepare_input(input, jobname, baseconfig, hard_case)

    # build Molecule3D representations per chain
    reps = []
    for chain in input["chains"]:
        if chain["class"] in ["protein", "RNA", "DNA"]:
            reps.append({
                "model": 0,
                "chain": chain["chain"],
                "resname": "",
                "style": "cartoon",
                "color": "alphafold",
                "residue_range": "",
                "around": 0,
                "byres": False
            })
        elif chain["class"] == "ligand" and "name" not in chain.keys():
            reps.append({
                "model": 0,
                "chain": chain["chain"],
                "resname": "LG1",
                "style": "stick",
                "color": "whiteCarbon",
                "residue_range": "",
                "around": 0,
                "byres": False
            })
        else:
            reps.append({
                "model": 0,
                "chain": chain["chain"],
                "resname": "LG1",
                "style": "sphere",
                "color": "whiteCarbon",
                "residue_range": "",
                "around": 0,
                "byres": False
            })

    if dry_run:
        return gr.Code(yaml_input, visible=True), gr.File(zip_archive, visible=True), gr.Markdown(f"""You can run your RFAA job using the following command: <pre>python -m rf2aa.run_inference --config-name {jobname}.yaml --config-path absolute/path/to/unzipped/{jobname}</pre>""", visible=True), Molecule3D(visible=False)
    else:
        pdb_file = run_rf2aa(jobname, zip_archive)
        return gr.Code(yaml_input, visible=True), gr.File(zip_archive, visible=True), gr.Markdown(visible=False), Molecule3D(pdb_file, reps=reps, visible=True)
with gr.Blocks() as demo:
    gr.Markdown("# RoseTTAFold All Atom UI")
    gr.Markdown("""This UI allows you to generate input files for RoseTTAFold All Atom (RFAA) using the CofoldingInput widget. The input files can be used to run RFAA on your local machine. <br />
If you launch the UI on your local machine, you can also run the RFAA prediction directly from here. <br />
More information in the official GitHub repository: [baker-laboratory/RoseTTAFold-All-Atom](https://github.com/baker-laboratory/RoseTTAFold-All-Atom)
""")
    jobname = gr.Textbox("job1", label="Job Name")
    with gr.Tab("Input"):
        inp = CofoldingInput(label="Input")
        hard_case = gr.Checkbox(False, label="Hard case (increase MAXCYCLE to 10)")
        # only allow running the prediction locally, not on Hugging Face Spaces
        if os.environ.get("SPACE_HOST") is not None:
            dry_run = gr.Checkbox(True, label="Only generate input files (dry run)", interactive=False)
        else:
            dry_run = gr.Checkbox(True, label="Only generate input files (dry run)")
    with gr.Tab("Base config"):
        base_config = gr.Code(baseconfig, label="Base config")

    btn = gr.Button("Run")
    config_file = gr.Code(label="YAML Hydra config for RFAA", visible=True)
    runfiles = gr.File(label="files to run RFAA", visible=False)
    instructions = gr.Markdown(visible=False)
    out = Molecule3D(visible=False)

    btn.click(predict, inputs=[inp, jobname, dry_run, base_config, hard_case], outputs=[config_file, runfiles, instructions, out])
if __name__ == "__main__":
    demo.launch(share=True)