Christina Theodoris commited on
Commit
feeecd0
1 Parent(s): dc1481d

Add function to create remainder emb for in silico overexpression batch

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +43 -12
geneformer/in_silico_perturber.py CHANGED
@@ -140,6 +140,18 @@ def make_comparison_batch(original_emb, indices_to_perturb):
140
  all_embs_list += [torch.cat(emb_list)]
141
  return torch.stack(all_embs_list)
142
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # average embedding position of goal cell states
144
  def get_cell_state_avg_embs(model,
145
  filtered_input_data,
@@ -188,6 +200,7 @@ def get_cell_state_avg_embs(model,
188
 
189
  # quantify cosine similarity of perturbed vs original or alternate states
190
  def quant_cos_sims(model,
 
191
  perturbation_batch,
192
  forward_batch_size,
193
  layer_to_quant,
@@ -226,8 +239,14 @@ def quant_cos_sims(model,
226
  minibatch_emb = outputs.hidden_states[layer_to_quant]
227
  if cell_states_to_model is None:
228
  minibatch_comparison = comparison_batch[i:max_range]
 
 
 
 
 
 
229
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
230
- else:
231
  for state in possible_states:
232
  cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
233
  del outputs
@@ -279,9 +298,9 @@ def pad_tensor_list(tensor_list, dynamic_or_constant, token_dictionary):
279
  class InSilicoPerturber:
280
  valid_option_dict = {
281
  "perturb_type": {"delete","overexpress","inhibit","activate"},
282
- "perturb_rank_shift": {None, int},
283
  "genes_to_perturb": {"all", list},
284
- "combos": {0,1,2},
285
  "anchor_gene": {None, str},
286
  "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
287
  "num_classes": {int},
@@ -326,7 +345,7 @@ class InSilicoPerturber:
326
  "overexpress": move gene to front of rank value encoding
327
  "inhibit": move gene to lower quartile of rank value encoding
328
  "activate": move gene to higher quartile of rank value encoding
329
- perturb_rank_shift : None, int
330
  Number of quartiles by which to shift rank of gene.
331
  For example, if perturb_type="activate" and perturb_rank_shift=1:
332
  genes in 4th quartile will move to middle of 3rd quartile.
@@ -414,6 +433,15 @@ class InSilicoPerturber:
414
  self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
415
 
416
  def validate_options(self):
 
 
 
 
 
 
 
 
 
417
  for attr_name,valid_options in self.valid_option_dict.items():
418
  attr_value = self.__dict__[attr_name]
419
  if type(attr_value) not in {list, dict}:
@@ -442,7 +470,7 @@ class InSilicoPerturber:
442
  elif self.perturb_type == "overexpress":
443
  logger.warning(
444
  "perturb_rank_shift set to None. " \
445
- "If perturb type is activate then gene is moved to front " \
446
  "of rank value encoding rather than shifted by quartile")
447
  self.perturb_rank_shift = None
448
 
@@ -626,13 +654,14 @@ class InSilicoPerturber:
626
  combo_lvl,
627
  self.nproc)
628
  cos_sims_data = quant_cos_sims(model,
629
- perturbation_batch,
630
- self.forward_batch_size,
631
- layer_to_quant,
632
- original_emb,
633
- indices_to_perturb,
634
- self.cell_states_to_model,
635
- state_embs_dict)
 
636
 
637
  if self.cell_states_to_model is None:
638
  # update cos sims dict
@@ -699,6 +728,7 @@ class InSilicoPerturber:
699
  0,
700
  self.nproc)
701
  cos_sims_data = quant_cos_sims(model,
 
702
  perturbation_batch,
703
  self.forward_batch_size,
704
  layer_to_quant,
@@ -715,6 +745,7 @@ class InSilicoPerturber:
715
  1,
716
  self.nproc)
717
  combo_cos_sims_data = quant_cos_sims(model,
 
718
  combo_perturbation_batch,
719
  self.forward_batch_size,
720
  layer_to_quant,
 
140
  all_embs_list += [torch.cat(emb_list)]
141
  return torch.stack(all_embs_list)
142
 
143
+ # perturbed cell emb removing the activated/overexpressed/inhibited gene emb
144
+ # so that only non-perturbed gene embeddings are compared to each other
145
+ # in original or perturbed context
146
+ def make_perturbed_remainder_batch(emb_batch, indices_to_remove):
147
+ if type(indices_to_remove) == int:
148
+ indices_to_keep = [i for i in range(emb_batch.size()[1])]
149
+ indices_to_keep.pop(indices_to_remove)
150
+ perturbed_remainder_batch = torch.stack([emb[indices_to_keep,:] for emb in emb_batch])
151
+ elif type(indices_to_remove) == list:
152
+ perturbed_remainder_batch = torch.stack([make_comparison_batch(emb_batch[i],indices_to_remove[i]) for i in range(len(emb_batch))])
153
+ return perturbed_remainder_batch
154
+
155
  # average embedding position of goal cell states
156
  def get_cell_state_avg_embs(model,
157
  filtered_input_data,
 
200
 
201
  # quantify cosine similarity of perturbed vs original or alternate states
202
  def quant_cos_sims(model,
203
+ perturb_type,
204
  perturbation_batch,
205
  forward_batch_size,
206
  layer_to_quant,
 
239
  minibatch_emb = outputs.hidden_states[layer_to_quant]
240
  if cell_states_to_model is None:
241
  minibatch_comparison = comparison_batch[i:max_range]
242
+ if perturb_type == "overexpress":
243
+ index_to_remove = 0
244
+ minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
245
+ # elif (perturb_type == "inhibit") or (perturb_type == "activate"):
246
+ # index_to_remove = placeholder
247
+ # minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
248
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
249
+ elif cell_states_to_model is not None:
250
  for state in possible_states:
251
  cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
252
  del outputs
 
298
  class InSilicoPerturber:
299
  valid_option_dict = {
300
  "perturb_type": {"delete","overexpress","inhibit","activate"},
301
+ "perturb_rank_shift": {None, 1, 2, 3},
302
  "genes_to_perturb": {"all", list},
303
+ "combos": {0, 1, 2},
304
  "anchor_gene": {None, str},
305
  "model_type": {"Pretrained","GeneClassifier","CellClassifier"},
306
  "num_classes": {int},
 
345
  "overexpress": move gene to front of rank value encoding
346
  "inhibit": move gene to lower quartile of rank value encoding
347
  "activate": move gene to higher quartile of rank value encoding
348
+ perturb_rank_shift : None, {1,2,3}
349
  Number of quartiles by which to shift rank of gene.
350
  For example, if perturb_type="activate" and perturb_rank_shift=1:
351
  genes in 4th quartile will move to middle of 3rd quartile.
 
433
  self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
434
 
435
  def validate_options(self):
436
+ # first disallow options under development
437
+ if self.perturb_type in ["inhibit", "activate"]:
438
+ logger.error(
439
+ f"In silico inhibition and activation currently under developemnt. " \
440
+ f"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
441
+ )
442
+ raise
443
+
444
+ # confirm arguments are within valid options and compatible with each other
445
  for attr_name,valid_options in self.valid_option_dict.items():
446
  attr_value = self.__dict__[attr_name]
447
  if type(attr_value) not in {list, dict}:
 
470
  elif self.perturb_type == "overexpress":
471
  logger.warning(
472
  "perturb_rank_shift set to None. " \
473
+ "If perturb type is overexpress then gene is moved to front " \
474
  "of rank value encoding rather than shifted by quartile")
475
  self.perturb_rank_shift = None
476
 
 
654
  combo_lvl,
655
  self.nproc)
656
  cos_sims_data = quant_cos_sims(model,
657
+ self.perturb_type,
658
+ perturbation_batch,
659
+ self.forward_batch_size,
660
+ layer_to_quant,
661
+ original_emb,
662
+ indices_to_perturb,
663
+ self.cell_states_to_model,
664
+ state_embs_dict)
665
 
666
  if self.cell_states_to_model is None:
667
  # update cos sims dict
 
728
  0,
729
  self.nproc)
730
  cos_sims_data = quant_cos_sims(model,
731
+ self.perturb_type,
732
  perturbation_batch,
733
  self.forward_batch_size,
734
  layer_to_quant,
 
745
  1,
746
  self.nproc)
747
  combo_cos_sims_data = quant_cos_sims(model,
748
+ self.perturb_type,
749
  combo_perturbation_batch,
750
  self.forward_batch_size,
751
  layer_to_quant,