Spaces:
Runtime error
Runtime error
import os, time, sys | |
# One-time bootstrap executed at import time: fetch the model weights, the
# RoseTTAFold2 sources, the ColabFold MMseqs2 client and HH-suite into the
# working directory when they are not already there.
_params_file = "RF2_apr23.pt"
_params_partial = "RF2_apr23.pt.aria2"

if not os.path.isfile(_params_file):
    # send param download into background
    os.system(
        "(apt-get install aria2; aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt) &"
    )

if not os.path.isdir("RoseTTAFold2"):
    print("install RoseTTAFold2")
    os.system("git clone https://github.com/sokrypton/RoseTTAFold2.git")
    print(os.listdir("RoseTTAFold2"))
    _se3_install = "cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install ."
    os.system(_se3_install)
    _api_fetch = "wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py"
    os.system(_api_fetch)

    # install hhsuite
    print("install hhsuite")
    os.makedirs("hhsuite", exist_ok=True)
    _hhsuite_fetch = "curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/"
    os.system(_hhsuite_fetch)
    print(os.listdir("hhsuite"))

# Block until the background download is done: the loop polls every 5 seconds
# for as long as the ".aria2" file next to the weights exists.
# NOTE(review): if the background command has not created that file yet when
# this check runs, the wait is skipped entirely - confirm this cannot happen.
if os.path.isfile(_params_partial):
    print("downloading RoseTTAFold2 params")
    while os.path.isfile(_params_partial):
        time.sleep(5)

os.environ["DGLBACKEND"] = "pytorch"
# Lets the later `from parsers import ...` / `from predict import ...` resolve.
sys.path.append("RoseTTAFold2/network")
if "hhsuite" not in os.environ["PATH"]:
    os.environ["PATH"] = os.environ["PATH"] + ":hhsuite/bin:hhsuite/scripts"
import matplotlib.pyplot as plt | |
import numpy as np | |
from parsers import parse_a3m | |
from api import run_mmseqs2 | |
import torch | |
from string import ascii_uppercase, ascii_lowercase | |
import hashlib, re, os | |
import random | |
from Bio.PDB import * | |
def get_hash(x):
    """Return the hexadecimal SHA-1 digest of the string ``x``."""
    hasher = hashlib.sha1()
    hasher.update(x.encode())
    return hasher.hexdigest()
# "A".."Z" followed by "a".."z", one character per list entry.
alphabet_list = [*ascii_uppercase, *ascii_lowercase]
from collections import OrderedDict, Counter | |
import gradio as gr | |
# Build the module-level predictor once; the dir() guard skips the work when
# a `pred` name already exists in this namespace.
if "pred" not in dir():
    from predict import Predictor

    print("compile RoseTTAFold2")
    model_params = "RF2_apr23.pt"
    _has_cuda = torch.cuda.is_available()
    if not _has_cuda:
        print("WARNING: using CPU")
    pred = Predictor(model_params, torch.device("cuda:0" if _has_cuda else "cpu"))
def get_unique_sequences(seq_list):
    """Return the distinct items of ``seq_list`` as a list, in first-seen order."""
    first_seen = OrderedDict()
    for item in seq_list:
        first_seen.setdefault(item, None)
    return list(first_seen)
def get_msa(seq, jobname, cov=50, id=90, max_msa=2048, mode="unpaired_paired"):
    """Build an MSA for one or more chains and write it to ``{jobname}/msa.a3m``.

    Identical chains are collapsed for the MMseqs2 queries and expanded again
    (joined with "/") in every row that is written out.

    Args:
        seq: a single sequence string or a list of chain sequences.
        jobname: output directory; working files go to ``{jobname}/msa``.
        cov: value passed to ``hhfilter -cov``.
        id: value passed to ``hhfilter -id`` (name kept for caller compatibility).
        max_msa: stop adding unpaired rows once the MSA has this many rows.
        mode: one of "unpaired", "paired", "unpaired_paired".

    Returns:
        None. The result is the ``msa.a3m`` file, whose first row is the query.
    """
    # Function-scope imports so the block is self-contained.
    import subprocess
    from collections import deque

    assert mode in ["unpaired", "paired", "unpaired_paired"]
    seqs = [seq] if isinstance(seq, str) else seq

    def _hhfilter(src, dst):
        # Argument list + no shell: `jobname` is user-supplied text, so it must
        # never be interpolated into a shell command line. As before, a
        # non-zero exit status is not checked here; a missing output file
        # surfaces when it is opened below.
        subprocess.run(
            ["hhfilter", "-i", src, "-id", str(id), "-cov", str(cov), "-o", dst],
            check=False,
        )

    # collapse homooligomeric sequences
    counts = Counter(seqs)
    u_seqs = list(counts.keys())
    u_nums = list(counts.values())

    # expand homooligomeric sequences
    first_seq = "/".join(x for x, n in zip(u_seqs, u_nums) for _ in range(n))
    msa = [first_seq]

    path = os.path.join(jobname, "msa")
    os.makedirs(path, exist_ok=True)

    if mode in ["paired", "unpaired_paired"] and len(u_seqs) > 1:
        print("getting paired MSA")
        out_paired = run_mmseqs2(u_seqs, f"{path}/", use_pairing=True)
        # sequences[k] collects the k-th record of every returned a3m, i.e. the
        # per-chain pieces of paired row k.
        sequences = []
        for a3m_lines in out_paired:
            n = -1
            for line in a3m_lines.split("\n"):
                if not line:
                    continue
                if line.startswith(">"):
                    n += 1
                    if len(sequences) < n + 1:
                        sequences.append([])
                else:
                    sequences[n].append(line)

        # filter MSA
        with open(f"{path}/paired_in.a3m", "w") as handle:
            for n, sequence in enumerate(sequences):
                handle.write(f">n{n}\n{''.join(sequence)}\n")
        _hhfilter(f"{path}/paired_in.a3m", f"{path}/paired_out.a3m")
        with open(f"{path}/paired_out.a3m", "r") as handle:
            for line in handle:
                if line.startswith(">"):
                    # Headers were written as ">n<k>"; recover k to look the
                    # surviving row up in `sequences`.
                    n = int(line[2:])
                    xs = sequences[n]
                    # expand homooligomeric sequences
                    xs = ["/".join([x] * num) for x, num in zip(xs, u_nums)]
                    msa.append("/".join(xs))

    if len(msa) < max_msa and (
        mode in ["unpaired", "unpaired_paired"] or len(u_seqs) == 1
    ):
        print("getting unpaired MSA")
        out = run_mmseqs2(u_seqs, f"{path}/")
        chain_lengths = [len(s) for s in u_seqs]
        pools = []
        for n, a3m_lines in enumerate(out):
            with open(f"{path}/in_{n}.a3m", "w") as handle:
                handle.write(a3m_lines)
            # filter
            _hhfilter(f"{path}/in_{n}.a3m", f"{path}/out_{n}.a3m")
            rows = deque()
            with open(f"{path}/out_{n}.a3m", "r") as handle:
                for line in handle:
                    if line.startswith(">"):
                        continue
                    # Hit for chain n only; all other chains are gap-filled.
                    xs = ["-" * length for length in chain_lengths]
                    xs[n] = line.rstrip()
                    # expand homooligomeric sequences
                    xs = ["/".join([x] * num) for x, num in zip(xs, u_nums)]
                    rows.append("/".join(xs))
            pools.append(rows)

        # Round-robin over the per-chain pools until the MSA is full or every
        # pool is empty; deque.popleft() is O(1) where list.pop(0) was O(n).
        remaining = sum(len(rows) for rows in pools)
        while len(msa) < max_msa and remaining > 0:
            for rows in pools:
                if rows:
                    msa.append(rows.popleft())
                    remaining -= 1
                if len(msa) == max_msa:
                    break

    with open(f"{jobname}/msa.a3m", "w") as handle:
        for n, sequence in enumerate(msa):
            handle.write(f">n{n}\n{sequence}\n")
from Bio.PDB.PDBExceptions import PDBConstructionWarning | |
import warnings | |
from Bio.PDB import * | |
import numpy as np | |
def add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname):
    """Convert the best model's PDB file to mmCIF and append pLDDT metric records.

    Reads ``{jobname}/rf2_seed{best_seed}_00_pred.pdb``, saves it as the
    matching ``.cif`` file and then appends ``_ma_qa_metric*`` blocks to it.

    Args:
        best_plddts: indexable per-residue scores; each value is multiplied
            by 100 before it is written.
        best_plddt: global score, written as-is with three decimals.
        best_seed: seed number used in the file names.
        jobname: directory that holds the prediction files.

    Returns:
        None.
    """
    stem = f"{jobname}/rf2_seed{best_seed}_00_pred"

    parser = PDBParser()
    # Installs a process-wide filter for Biopython construction warnings.
    warnings.filterwarnings("ignore", category=PDBConstructionWarning)
    structure = parser.get_structure("pdb", f"{stem}.pdb")

    writer = MMCIFIO()
    writer.set_structure(structure)
    writer.save(f"{stem}.cif")

    header = f"""#
loop_
_ma_qa_metric.id
_ma_qa_metric.mode
_ma_qa_metric.name
_ma_qa_metric.software_group_id
_ma_qa_metric.type
1 global pLDDT 1 pLDDT
2 local pLDDT 1 pLDDT
#
_ma_qa_metric_global.metric_id 1
_ma_qa_metric_global.metric_value {best_plddt:.3f}
_ma_qa_metric_global.model_id 1
_ma_qa_metric_global.ordinal_id 1
#
loop_
_ma_qa_metric_local.label_asym_id
_ma_qa_metric_local.label_comp_id
_ma_qa_metric_local.label_seq_id
_ma_qa_metric_local.metric_id
_ma_qa_metric_local.metric_value
_ma_qa_metric_local.model_id
_ma_qa_metric_local.ordinal_id"""

    # NOTE(review): `i` restarts at 0 for every chain, so each chain reads
    # best_plddts from its beginning - confirm this matches the layout of the
    # score array for multi-chain models.
    rows = [
        f"\n{chain.id} {residue.resname} {residue.id[1]} 2 {best_plddts[i]*100:.2f} 1 {residue.id[1]}"
        for chain in structure[0]
        for i, residue in enumerate(chain)
    ]

    with open(f"{stem}.cif", "a") as f:
        f.write(header + "".join(rows) + "\n#")
def predict(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
    mode="web",
):
    """Run the module-level predictor on ``sequence`` and return the best model's file path.

    One model is produced per seed in ``range(random_seed, random_seed +
    num_models)``; the seed whose mean "lddt" (read from the seed's .npz
    output) is highest wins.

    Args:
        sequence: residues, with ":" or "/" between chains; anything outside
            A-Z and ":" is stripped after upper-casing.
        jobname: prefix of the output directory. Non-word characters are
            replaced by "_" because the value arrives from a web form and is
            used as a path component.
        sym: symmetry letter, one of X, C, D, T, O, I.
        order: symmetry order (used for X, C, D; ignored for T, O, I).
        msa_concat_mode, use_dropout: forwarded to ``pred.predict``.
        msa_method: "mmseqs2" or "single_sequence".
        pair_mode: forwarded to ``get_msa`` as ``mode``.
        collapse_identical: de-duplicate chains before replication.
        num_recycles, max_msa, random_seed, num_models: int-like values
            (the UI delivers them as strings).
        use_mlm: use an MSA mask fraction of 0.15 instead of 0.
        mode: "web" returns an mmCIF path (with pLDDT records appended),
            anything else returns the PDB path.

    Returns:
        Path of the best model, ``{jobname}/rf2_seed{seed}_00_pred.cif`` or
        ``.pdb``.

    Raises:
        gr.Error: sequence too long on Spaces, no residues left after
            cleaning, unknown ``sym`` / ``msa_method``, or ``num_models < 1``.
    """
    if os.path.exists("/home/user/app"):  # crude check if on spaces
        if len(sequence) > 600:
            raise gr.Error(
                f"Your sequence is too long ({len(sequence)}). "
                "Please use the full version of RoseTTAfold2 directly from GitHub."
            )

    random_seed = int(random_seed)
    num_models = int(num_models)
    max_msa = int(max_msa)
    num_recycles = int(num_recycles)
    order = int(order)
    max_extra_msa = max_msa * 8

    # Fail early with a message the UI can show instead of a bare KeyError,
    # a stale/missing msa.a3m or an undefined "best" model further down.
    if sym not in ("X", "C", "D", "T", "O", "I"):
        raise gr.Error(f"Unknown symmetry '{sym}'. Use one of X, C, D, T, O, I.")
    if msa_method not in ("mmseqs2", "single_sequence"):
        raise gr.Error(f"Unsupported msa_method '{msa_method}'.")
    if num_models < 1:
        raise gr.Error("num_models must be at least 1.")

    print("sequence", sequence)
    sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())
    sequence = re.sub(":+", ":", sequence)
    sequence = re.sub("^[:]+", "", sequence)
    sequence = re.sub("[:]+$", "", sequence)
    print("sequence", sequence)
    if not sequence:
        raise gr.Error("The sequence contains no valid residues.")

    if sym in ("X", "C"):
        copies = order
    elif sym == "D":
        copies = order * 2
    else:
        copies = {"T": 12, "O": 24, "I": 60}[sym]
        order = ""
    symm = sym + str(order)

    sequences = sequence.replace(":", "/").split("/")
    if collapse_identical:
        u_sequences = get_unique_sequences(sequences)
    else:
        u_sequences = sequences
    sequences = sum([u_sequences] * copies, [])
    lengths = [len(s) for s in sequences]

    # TODO
    subcrop = 1000 if sum(lengths) > 1400 else -1

    sequence = "/".join(sequences)
    # jobname is untrusted text that becomes a directory name: keep only word
    # characters so it cannot contain path separators or shell metacharacters.
    jobname = re.sub(r"\W+", "_", str(jobname))
    jobname = jobname + "_" + symm + "_" + get_hash(sequence)[:5]

    print(f"jobname: {jobname}")
    print(f"lengths: {lengths}")
    print("final_sequence", u_sequences)

    os.makedirs(jobname, exist_ok=True)
    if msa_method == "mmseqs2":
        get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
    else:  # "single_sequence" (validated above)
        u_sequence = "/".join(u_sequences)
        with open(f"{jobname}/msa.a3m", "w") as a3m:
            a3m.write(f">{jobname}\n{u_sequence}\n")
    # TODO: a "custom_a3m" msa_method (user-uploaded alignment) is not supported yet.

    best_plddt = None
    best_plddts = None
    best_seed = None
    mlm = 0.15 if use_mlm else 0
    for seed in range(random_seed, random_seed + num_models):
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        npz = f"{jobname}/rf2_seed{seed}_00.npz"
        print("MLM", mlm, use_mlm)
        pred.predict(
            inputs=[f"{jobname}/msa.a3m"],
            out_prefix=f"{jobname}/rf2_seed{seed}",
            symm=symm,
            ffdb=None,  # TODO (templates),
            n_recycles=num_recycles,
            msa_mask=mlm,
            msa_concat_mode=msa_concat_mode,
            nseqs=max_msa,
            nseqs_full=max_extra_msa,
            subcrop=subcrop,
            is_training=use_dropout,
        )
        # Load the archive once and close it again; the array is fully read
        # into memory on item access, so it stays valid after the `with`.
        with np.load(npz) as scores:
            lddt = scores["lddt"]
        plddt = lddt.mean()
        if best_plddt is None or plddt > best_plddt:
            best_plddt = plddt
            best_plddts = lddt
            best_seed = seed

    if mode == "web":
        # Mol* only displays AlphaFold plDDT if they are in a cif.
        add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname)
        return f"{jobname}/rf2_seed{best_seed}_00_pred.cif"
    # for api just return a pdb file
    return f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
def predict_api(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """Run ``predict`` in "api" mode and return the text of the file it reports.

    All arguments are forwarded to ``predict`` unchanged and in order.
    """
    settings = (
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
    )
    result_path = predict(*settings, mode="api")
    with open(f"{result_path}") as fp:
        contents = fp.read()
    return contents
def molecule(input_pdb, public_link):
    """Return an ``<iframe>`` snippet that shows a structure file in PDBe Mol*.

    The iframe's ``srcdoc`` holds a complete HTML page with download buttons
    for the CIF file and its PDB sibling plus a Mol* viewer that loads
    ``{public_link}/file={input_pdb}`` as CIF.

    Args:
        input_pdb: path of the structure file relative to the app root.
        public_link: base URL under which the app serves files.

    Returns:
        HTML string containing a single ``<iframe>`` element.
    """
    # Function-scope imports so the block is self-contained.
    import html
    import json

    print(input_pdb)
    link = public_link + "/file=" + input_pdb
    print(link)

    # Swap only the file extension; a blanket replace would also rewrite a
    # ".cif" that happens to occur in the host name or job directory.
    if link.endswith(".cif"):
        pdb_link = link[: -len(".cif")] + ".pdb"
    else:
        pdb_link = link

    # The link contains user-chosen text, so escape it for every context it
    # is placed in: HTML attribute values and a JavaScript string literal.
    cif_href = html.escape(link, quote=True)
    pdb_href = html.escape(pdb_link, quote=True)
    # json.dumps yields a quoted JS string literal; "</" is split so the value
    # can never terminate the surrounding <script> element.
    js_url = json.dumps(link).replace("</", "<\\/")

    x = (
        """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
<title>PDBe Molstar - Helper functions</title>
<!-- Molstar CSS & JS -->
<link rel="stylesheet" type="text/css" href="https://www.ebi.ac.uk/pdbe/pdb-component-library/css/pdbe-molstar-light-3.1.0.css">
<script type="text/javascript" src="https://www.ebi.ac.uk/pdbe/pdb-component-library/js/pdbe-molstar-plugin-3.1.0.js"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.msp-plugin ::-webkit-scrollbar-thumb {
background-color: #474748 !important;
}
.viewerSection {
margin: 120px 0 0 0px;
}
#myViewer{
float:left;
width:100%;
height: 800px;
position:relative;
}
.btn{
font-family: "Open Sans", sans-serif;
display: inline-block;
outline: none;
cursor: pointer;
font-weight: 600;
border-radius: 3px;
padding: 12px 24px;
border: 0;
margin:0 10px;
line-height: 1.15;
font-size: 16px;
text-decoration: none;
}
.btn-orange{
background: #ff5000;
color: #fff;
}
.btn-gray{
color: #3a4149;
background: #e7ebee;
}
.btn:hover{
transition: all .1s ease;
box-shadow: 0 0 0 0 #fff, 0 0 0 3px #ddd;}
.text-center{
display: flex;
align-items: center;
justify-content: center;
padding: 20px 0;
}
.flex{
padding: 10px;
display: flex;
align-items: center;
justify-content: center;
width:fit-content;
}
.flex svg{
margin-right: 10px;
width:16px;
height:16px;
}
.flex a{
margin:0 10px;
}
</style>
</head>
<body>
<div class="text-center">
<a class="btn btn-orange flex" href=\""""
        + cif_href
        + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
</svg> <span>CIF File</span></a>
<a class="btn btn-gray flex" href=\""""
        + pdb_href
        + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
</svg> <span>PDB File</span></a>
</div>
<div class="viewerSection">
<!-- Molstar container -->
<div id="myViewer"></div>
</div>
<script>
//Create plugin instance
var viewerInstance = new PDBeMolstarPlugin();
//Set options (Checkout available options list in the documentation)
var options = {
customData: {
url: """
        + js_url
        + """,
format: "cif"
},
alphafoldView: true,
bgColor: {r:255, g:255, b:255},
//hideCanvasControls: ["selection", "animation", "controlToggle", "controlInfo"]
}
//Get element from HTML/Template to place the viewer
var viewerContainer = document.getElementById("myViewer");
//Call render method to display the 3D view
viewerInstance.render(viewerContainer, options);
</script>
</body>
</html>"""
    )

    # The whole page is an attribute value here, so it is escaped once more;
    # the browser undoes this layer when it parses the srcdoc attribute.
    srcdoc = html.escape(x, quote=True)
    return f"""<iframe style="width: 100%; height: 1000px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def predict_web(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """Run ``predict`` in "web" mode and return the viewer HTML for its result.

    The base URL handed to ``molecule`` is the hosted Space address when
    ``/home/user/app`` exists and the local Gradio address otherwise.
    """
    on_spaces = os.path.exists("/home/user/app")
    public_link = (
        "https://simonduerr-rosettafold2.hf.space"
        if on_spaces
        else "http://localhost:7860"
    )
    settings = (
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
    )
    result_path = predict(*settings, mode="web")
    return molecule(result_path, public_link)
# Gradio front end: one visible "Run" button that renders the result in a
# Mol* viewer, plus a hidden button that exposes the same inputs as the
# "rosettafold2" API endpoint.
with gr.Blocks() as rosettafold:
    gr.Markdown("# RoseTTAFold2")
    gr.Markdown(
        """If using please cite: [manuscript](https://www.biorxiv.org/content/10.1101/2023.05.24.542179v1)
<br> Heavily based on [RoseTTAFold2 ColabFold notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/RoseTTAFold2.ipynb)"""
    )
    with gr.Accordion("How to use in PyMol", open=False):
        gr.HTML(
            """<code>os.system('wget https://huggingface.co/spaces/simonduerr/rosettafold2/raw/main/rosettafold_pymol.py') <br>
run rosettafold_pymol.py <br>
rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models] <br>
color_plddt jobname</code>
"""
        )

    sequence = gr.Textbox(
        label="sequence",
        value="PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK",
    )
    jobname = gr.Textbox(label="jobname", value="test")

    with gr.Accordion("Additional settings", open=False):
        sym = gr.Textbox(label="sym", value="X")
        order = gr.Slider(label="order", value=1, step=1, minimum=1, maximum=12)
        msa_concat_mode = gr.Dropdown(
            label="msa_concat_mode",
            value="default",
            choices=["diag", "repeat", "default"],
        )
        # dont allow custom a3m for now , "custom_a3m"
        msa_method = gr.Dropdown(
            label="msa_method",
            value="single_sequence",
            choices=["mmseqs2", "single_sequence"],
        )
        pair_mode = gr.Dropdown(
            label="pair_mode",
            value="unpaired_paired",
            choices=["unpaired_paired", "paired", "unpaired"],
        )
        num_recycles = gr.Dropdown(
            label="num_recycles",
            value="6",
            choices=["0", "1", "3", "6", "12", "24"],
        )
        use_mlm = gr.Checkbox(label="use_mlm", value=False)
        use_dropout = gr.Checkbox(label="use_dropout", value=False)
        collapse_identical = gr.Checkbox(label="collapse_identical", value=False)
        max_msa = gr.Dropdown(
            choices=["16", "32", "64", "128", "256", "512"],
            value="16",
            label="max_msa",
        )
        random_seed = gr.Textbox(label="random_seed", value=0)
        num_models = gr.Dropdown(
            label="num_models",
            value="1",
            choices=["1", "2", "4", "8", "16", "32"],
        )

    btn = gr.Button("Run", visible=False)
    btn_web = gr.Button("Run")
    output_plain = gr.HTML()
    output = gr.HTML()

    # Both handlers take the same components, in the order of their signatures.
    _handler_inputs = (
        sequence,
        jobname,
        sym,
        order,
        msa_concat_mode,
        msa_method,
        pair_mode,
        collapse_identical,
        num_recycles,
        use_mlm,
        use_dropout,
        max_msa,
        random_seed,
        num_models,
    )
    btn.click(
        fn=predict_api,
        inputs=list(_handler_inputs),
        outputs=output_plain,
        api_name="rosettafold2",
    )
    btn_web.click(
        fn=predict_web,
        inputs=list(_handler_inputs),
        outputs=output,
    )

rosettafold.launch()