sagawa committed on
Commit
1d7f8f6
1 Parent(s): 482be2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -42
app.py CHANGED
@@ -28,67 +28,50 @@ class Config:
28
  self.seed = 42
29
 
30
 
 
31
  def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()):
32
- results = {
33
- "file_name": [],
34
- "raw prediction value": [],
35
- "binary prediction value": [],
36
- }
37
  file_names = []
38
  input_sequences = []
39
 
40
- os.system("chmod 777 bin/foldseek")
41
  for pdb_file in pdb_files:
42
  pdb_path = pdb_file.name
 
43
  sequences = get_foldseek_seq(pdb_path)
44
-
45
- file_name = os.path.basename(pdb_path)
46
  if not sequences:
47
- results["file_name"].append(file_name)
48
  results["raw prediction value"].append(None)
49
  results["binary prediction value"].append(None)
50
  continue
51
 
52
  sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
53
- file_names.append(file_name)
54
  input_sequences.append(sequence)
55
 
56
- raw_pred, binary_pred = predict_stability_core(
57
- model_choice, organism_choice, input_sequences, cfg
58
- )
59
  results["file_name"] = results["file_name"] + file_names
60
- results["raw prediction value"] = results["raw prediction value"] + raw_pred
61
- results["binary prediction value"] = (
62
- results["binary prediction value"] + binary_pred
63
- )
64
-
65
  df = pd.DataFrame(results)
66
  output_csv = "/tmp/predictions.csv"
67
  df.to_csv(output_csv, index=False)
68
 
69
  return output_csv
70
 
71
-
72
- def predict_stability_with_sequence(
73
- model_choice, organism_choice, sequence, cfg=Config()
74
- ):
75
- if not sequence:
76
- return "No valid sequence provided."
77
  try:
78
- raw_pred, binary_pred = predict_stability_core(
79
- model_choice, organism_choice, [sequence], cfg
80
- )
81
- df = pd.DataFrame(
82
- {
83
- "sequence": sequence,
84
- "raw prediction value": raw_pred,
85
- "binary prediction value": binary_pred,
86
- }
87
- )
88
  output_csv = "/tmp/predictions.csv"
89
  df.to_csv(output_csv, index=False)
90
 
91
- return output_csv
92
  except Exception as e:
93
  return f"An error occurred: {str(e)}"
94
 
@@ -127,7 +110,6 @@ def predict(cfg, sequences):
127
  cfg.model_path, padding_side=cfg.padding_side
128
  )
129
  cfg.tokenizer = tokenizer
130
-
131
  dataset = PLTNUMDataset(cfg, df, train=False)
132
  dataloader = DataLoader(
133
  dataset,
@@ -144,9 +126,9 @@ def predict(cfg, sequences):
144
  model.eval()
145
  predictions = []
146
 
147
- with torch.no_grad():
148
- for inputs, _ in dataloader:
149
- inputs = inputs.to(cfg.device)
150
  with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
151
  preds = (
152
  torch.sigmoid(model(inputs))
@@ -156,7 +138,7 @@ def predict(cfg, sequences):
156
  predictions += preds.cpu().tolist()
157
 
158
  predictions = list(itertools.chain.from_iterable(predictions))
159
-
160
  return predictions, [1 if x > 0.5 else 0 for x in predictions]
161
 
162
 
@@ -192,7 +174,9 @@ with gr.Blocks() as demo:
192
  gr.Markdown("### Upload your PDB files:")
193
  pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")
194
  predict_button = gr.Button("Predict Stability")
195
- prediction_output = gr.File(label="Download Predictions")
 
 
196
 
197
  predict_button.click(
198
  fn=predict_stability_with_pdb,
@@ -208,7 +192,9 @@ with gr.Blocks() as demo:
208
  lines=8,
209
  )
210
  predict_button = gr.Button("Predict Stability")
211
- prediction_output = gr.File(label="Download Predictions")
 
 
212
 
213
  predict_button.click(
214
  fn=predict_stability_with_sequence,
 
28
  self.seed = 42
29
 
30
 
31
+
32
  def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()):
33
+ results = {"file_name": [],
34
+ "raw prediction value": [],
35
+ "binary prediction value": []
36
+ }
 
37
  file_names = []
38
  input_sequences = []
39
 
 
40
  for pdb_file in pdb_files:
41
  pdb_path = pdb_file.name
42
+ os.system("chmod 777 bin/foldseek")
43
  sequences = get_foldseek_seq(pdb_path)
 
 
44
  if not sequences:
45
+ results["file_name"].append(pdb_file.name.split("/")[-1])
46
  results["raw prediction value"].append(None)
47
  results["binary prediction value"].append(None)
48
  continue
49
 
50
  sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
51
+ file_names.append(pdb_file.name.split("/")[-1])
52
  input_sequences.append(sequence)
53
 
54
+ raw_prediction, binary_prediction = predict_stability_core(model_choice, organism_choice, input_sequences, cfg)
 
 
55
  results["file_name"] = results["file_name"] + file_names
56
+ results["raw prediction value"] = results["raw prediction value"] + raw_prediction
57
+ results["binary prediction value"] = results["binary prediction value"] + binary_prediction
58
+
 
 
59
  df = pd.DataFrame(results)
60
  output_csv = "/tmp/predictions.csv"
61
  df.to_csv(output_csv, index=False)
62
 
63
  return output_csv
64
 
65
+ def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=Config()):
 
 
 
 
 
66
  try:
67
+ if not sequence:
68
+ return "No valid sequence provided."
69
+ raw_prediction, binary_prediction = predict_stability_core(model_choice, organism_choice, [sequence], cfg)
70
+ df = pd.DataFrame({"sequence": sequence, "raw prediction value": raw_prediction, "binary prediction value": binary_prediction})
 
 
 
 
 
 
71
  output_csv = "/tmp/predictions.csv"
72
  df.to_csv(output_csv, index=False)
73
 
74
+ return output_csv
75
  except Exception as e:
76
  return f"An error occurred: {str(e)}"
77
 
 
110
  cfg.model_path, padding_side=cfg.padding_side
111
  )
112
  cfg.tokenizer = tokenizer
 
113
  dataset = PLTNUMDataset(cfg, df, train=False)
114
  dataloader = DataLoader(
115
  dataset,
 
126
  model.eval()
127
  predictions = []
128
 
129
+ for inputs, _ in dataloader:
130
+ inputs = inputs.to(cfg.device)
131
+ with torch.no_grad():
132
  with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
133
  preds = (
134
  torch.sigmoid(model(inputs))
 
138
  predictions += preds.cpu().tolist()
139
 
140
  predictions = list(itertools.chain.from_iterable(predictions))
141
+
142
  return predictions, [1 if x > 0.5 else 0 for x in predictions]
143
 
144
 
 
174
  gr.Markdown("### Upload your PDB files:")
175
  pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")
176
  predict_button = gr.Button("Predict Stability")
177
+ prediction_output = gr.File(
178
+ label="Download Predictions"
179
+ )
180
 
181
  predict_button.click(
182
  fn=predict_stability_with_pdb,
 
192
  lines=8,
193
  )
194
  predict_button = gr.Button("Predict Stability")
195
+ prediction_output = gr.File(
196
+ label="Download Predictions"
197
+ )
198
 
199
  predict_button.click(
200
  fn=predict_stability_with_sequence,