Files changed (1)
  1. app.py +53 -16
app.py CHANGED
@@ -8,7 +8,8 @@ import matplotlib
 
 matplotlib.use("Agg")
 import matplotlib.pyplot as plt
-from transformers import pipeline as pl
+#from transformers import pipeline as pl
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from GPUtil import showUtilization as gpu_usage
 
 import pandas as pd
@@ -20,6 +21,7 @@ import torch
 import gc
 import jax
 from numba import cuda
+import math
 print('GPU available',torch.cuda.is_available())
 #print('__CUDA Device Name:',torch.cuda.get_device_name(0))
 print(os.getcwd())
@@ -81,27 +83,58 @@ def predict_structure(prefix, feature_dict, model_runners, random_seed=0):
         f.write(protein.to_pdb(unrelaxed_protein))
     return plddts
 
+def compute_perplexity(model, tokenizer, sequence):
+    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
+    with torch.no_grad():
+        outputs = model(input_ids, labels=input_ids)
+    loss, logits = outputs[:2]
+    return math.exp(loss)
+
 @ray.remote(num_gpus=1, max_calls=1)
 def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
     print("running protgpt2")
     print(gpu_usage())
-    protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
-    sequences = protgpt2(
-        startsequence,
-        max_length=length,
-        do_sample=True,
-        top_k=top_k_poolsize,
-        repetition_penalty=repetitionPenalty,
-        num_return_sequences=max_seqs,
-        eos_token_id=0,
-    )
+    seqs_to_sample = max_seqs * 10  # oversample 10x, keep only the best-scoring tenth
+    #protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
+    model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2")
+    tokenizer = GPT2Tokenizer.from_pretrained("nferruz/ProtGPT2")
+    input_ids = tokenizer.encode(startsequence, return_tensors='pt')
+
+    sequences = model.generate(input_ids,
+                               max_length=length,
+                               do_sample=True,
+                               top_k=top_k_poolsize,
+                               repetition_penalty=repetitionPenalty,
+                               num_return_sequences=seqs_to_sample,
+                               eos_token_id=0)
+    filtered_sequences = []
+    for sequence in sequences:
+        decoded_seq = tokenizer.decode(sequence)
+        # No newlines in first line and avoid truncation
+        if '\n' not in decoded_seq[0:60] and decoded_seq.count('<|endoftext|>') >= 2:
+            clean_seq = decoded_seq.split('<|endoftext|>')[0]
+            ppl = compute_perplexity(model, tokenizer, clean_seq)
+            filtered_sequences.append((clean_seq, ppl / len(clean_seq)))
+
+    ## TODO: show a warning if not enough sequences fulfill the criteria!
+    selected_sequences = sorted(filtered_sequences, key=lambda x: x[1])[:max_seqs]
+    # sequences = protgpt2(
+    #     startsequence,
+    #     max_length=length,
+    #     do_sample=True,
+    #     top_k=top_k_poolsize,
+    #     repetition_penalty=repetitionPenalty,
+    #     num_return_sequences=seqs_to_sample,
+    #     eos_token_id=0,
+    # )
+
     print("Cleaning up after protGPT2")
     #print(gpu_usage())
     #torch.cuda.empty_cache()
     #device = cuda.get_current_device()
     #device.reset()
     #print(gpu_usage())
+    return selected_sequences
 
 @ray.remote(num_gpus=1, max_calls=1)
 def run_alphafold(startsequence):
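
The oversample-filter-rank flow this hunk introduces can be sanity-checked outside Ray and Gradio. Below is a minimal sketch assuming the nferruz/ProtGPT2 checkpoint and the diff's sampling settings (top_k=950, repetition_penalty=1.2, 10x oversampling); the function names and the small n_keep value are illustrative, not part of app.py:

import math
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2")
tokenizer = GPT2Tokenizer.from_pretrained("nferruz/ProtGPT2")
model.eval()

def perplexity(sequence):
    # exp of the mean token loss, i.e. already averaged over the sequence
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
    with torch.no_grad():
        loss = model(input_ids, labels=input_ids).loss
    return math.exp(loss.item())

def generate_ranked(start, n_keep=2, max_length=100):
    input_ids = tokenizer.encode(start, return_tensors="pt")
    outputs = model.generate(input_ids,
                             max_length=max_length,
                             do_sample=True,
                             top_k=950,
                             repetition_penalty=1.2,
                             num_return_sequences=n_keep * 10,  # oversample 10x
                             eos_token_id=0)
    scored = []
    for out in outputs:
        decoded = tokenizer.decode(out)
        # same filter as the hunk: clean first line, no truncated sequences
        if "\n" not in decoded[0:60] and decoded.count("<|endoftext|>") >= 2:
            seq = decoded.split("<|endoftext|>")[0]
            scored.append((seq, perplexity(seq) / len(seq)))
    # sorted() is used deliberately: list.sort() returns None and cannot be sliced
    return sorted(scored, key=lambda x: x[1])[:n_keep]

for seq, score in generate_ranked("M"):
    print(round(score, 3), seq[:60].replace("\n", " "))

On CPU this is slow, but it exercises the same filtering and ranking logic end to end.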
@@ -140,9 +173,13 @@ def update_protGPT2(inp, length,repetitionPenalty, top_k_poolsize, max_seqs):
     seqlen = length
     generated_seqs = ray.get(run_protgpt2.remote(startsequence, seqlen, repetitionPenalty, top_k_poolsize, max_seqs))
-    gen_seqs = [x["generated_text"] for x in generated_seqs]
-    print(gen_seqs)
+    # run_protgpt2 now filters out truncated sequences and returns the
+    # best-scoring (sequence, perplexity) tuples, so only the sequences
+    # themselves need to be unpacked here
+    sel_seqs = [s for s, ppl in generated_seqs]
+
+    print(sel_seqs)
     sequencestxt = ""
-    for i, seq in enumerate(gen_seqs):
+    for i, seq in enumerate(sel_seqs):
         s = seq.replace("\n","")
         seqlen = len(s)
         s = "\n".join([s[i:i+70] for i in range(0, len(s), 70)])
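
As context for the loop above: the slicing idiom reflows each selected sequence into FASTA-style 70-character lines. A quick illustration with a synthetic sequence (the string is made up):

s = "ACDEFGHIKLMNPQRSTVWY" * 8  # 160 residues, made-up sequence
print("\n".join(s[i:i+70] for i in range(0, len(s), 70)))
# prints two 70-residue lines followed by one 20-residue line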
@@ -360,8 +397,8 @@ with proteindream:
         )
     with gr.Box():
         with gr.Row():
-            inp = gr.Textbox(placeholder="M", label="Start sequence")
-            length = gr.Number(value=50, label="Max sequence length")
+            inp = gr.Textbox(placeholder="MTYKLILNGKTLKGETTT", label="Start sequence")
+            length = gr.Number(value=100, label="Max sequence length")
         with gr.Row():
             repetitionPenalty = gr.Slider(minimum=1, maximum=5,value=1.2, label="Repetition penalty")
             top_k_poolsize = gr.Slider(minimum=700, maximum=52056,value=950, label="Top-K sampling pool size")
 
8
 
9
  matplotlib.use("Agg")
10
  import matplotlib.pyplot as plt
11
+ #from transformers import pipeline as pl
12
+ from transformers import GPT2LMHeadModel , GPT2Tokenizer
13
  from GPUtil import showUtilization as gpu_usage
14
 
15
  import pandas as pd
 
21
  import gc
22
  import jax
23
  from numba import cuda
24
+ import math
25
  print('GPU available',torch.cuda.is_available())
26
  #print('__CUDA Device Name:',torch.cuda.get_device_name(0))
27
  print(os.getcwd())
 
83
  f.write(protein.to_pdb(unrelaxed_protein))
84
  return plddts
85
 
86
+ def compute_perplexity(model, tokenizer, sequence):
87
+ input_ids = torch.tensor(tokenizer.encode(sentence)).unsqueeze(0)
88
+ with torch.no_grad():
89
+ outputs = model(input_ids, labels=input_ids)
90
+ loss, logits = outputs[:2]
91
+ return math.exp(loss)
92
+
93
  @ray.remote(num_gpus=1, max_calls=1)
94
  def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
95
  print("running protgpt2")
96
  print(gpu_usage())
97
+ seqs_to_sample = max_seqs*10 # get the top 10
98
+ #protgpt2 = pl("text-generation", model="nferruz/ProtGPT2")
99
+ model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2")
100
+ tokenizer = GPT2Tokenizer.from_pretrained("nferruz/ProtGPT2")
101
+ input_ids = tokenizer.encode(startsequence, return_tensors='pt')
102
+
103
+ sequences = model.generate(input_ids,
104
+ max_length=length,
105
+ do_sample=True,
106
+ top_k=top_k_poolsize,
107
+ repetition_penalty=repetitionPenalty,
108
+ num_return_sequences=seqs_to_sample,
109
+ eos_token_id=0)
110
+ filtered_sequences = []
111
+ for sequence in sequences:
112
+ decoded_seq = tokenizer.decode(seq)
113
+ # No newlines in first line and avoid truncation
114
+ if '\n' not in decoded_seq[0:60] and decoded_seq.count('<|endoftext|>')>=2:
115
+ clean_seq = decoded_seq.split('<|endoftext|>')[0]
116
+ ppl = compute_perplexity(model, tokenizer, clean_seq)
117
+ filtered_sequences.append((clean_seq, ppl/len(clean_seq)))
118
+
119
+ ## THis needs to be fixed to show warning if not enough sequences fulfill the criteria!
120
+ selected_sequences = filtered_sequences.sort(key = lambda x: x[2])[:max_seqs]
121
+ # sequences = protgpt2(
122
+ # startsequence,
123
+ # max_length=length,
124
+ # do_sample=True,
125
+ # top_k=top_k_poolsize,
126
+ # repetition_penalty=repetitionPenalty,
127
+ # num_return_sequences=seqs_to_sample,
128
+ # eos_token_id=0,
129
+ # )
130
+
131
  print("Cleaning up after protGPT2")
132
  #print(gpu_usage())
133
  #torch.cuda.empty_cache()
134
  #device = cuda.get_current_device()
135
  #device.reset()
136
  #print(gpu_usage())
137
+ return selected_sequences
138
 
139
  @ray.remote(num_gpus=1, max_calls=1)
140
  def run_alphafold(startsequence):
 
173
  seqlen = length
174
  generated_seqs = ray.get(run_protgpt2.remote(startsequence, seqlen, repetitionPenalty, top_k_poolsize, max_seqs))
175
  gen_seqs = [x["generated_text"] for x in generated_seqs]
176
+ # Make sure sequences weren't truncated due to the length cutoff
177
+
178
+ # Select the best scoring top 10th:
179
+
180
+ print(sel_seqs)
181
  sequencestxt = ""
182
+ for i, seq in enumerate(sel_seqs):
183
  s = seq.replace("\n","")
184
  seqlen = len(s)
185
  s = "\n".join([s[i:i+70] for i in range(0, len(s), 70)])
 
397
  )
398
  with gr.Box():
399
  with gr.Row():
400
+ inp = gr.Textbox(placeholder="MTYKLILNGKTLKGETTT", label="Start sequence")
401
+ length = gr.Number(value=100, label="Max sequence length")
402
  with gr.Row():
403
  repetitionPenalty = gr.Slider(minimum=1, maximum=5,value=1.2, label="Repetition penalty")
404
  top_k_poolsize = gr.Slider(minimum=700, maximum=52056,value=950, label="Top-K sampling pool size")
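
One design note on the new ranking key: compute_perplexity exponentiates the mean token loss, so perplexity is already length-normalized; dividing it again by len(clean_seq) systematically favors longer sequences. If that bias is unintended, ranking by raw perplexity is the drop-in alternative. A small worked illustration with made-up loss values:

import math

mean_loss = 2.0            # hypothetical, identical for both sequences
ppl = math.exp(mean_loss)  # ~7.39: the same per-token score for both
print(ppl / 50)            # 0.148 for a 50-residue sequence
print(ppl / 200)           # 0.037 for a 200-residue one -> ranks "better"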