Added feature to perturb a set of indices to help with debugging and with very large runtimes
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -604,6 +604,7 @@ class InSilicoPerturber:
|
|
604 |
"filter_data": {None, dict},
|
605 |
"cell_states_to_model": {None, dict},
|
606 |
"max_ncells": {None, int},
|
|
|
607 |
"emb_layer": {-1, 0},
|
608 |
"forward_batch_size": {int},
|
609 |
"nproc": {int},
|
@@ -622,6 +623,7 @@ class InSilicoPerturber:
|
|
622 |
filter_data=None,
|
623 |
cell_states_to_model=None,
|
624 |
max_ncells=None,
|
|
|
625 |
emb_layer=-1,
|
626 |
forward_batch_size=100,
|
627 |
nproc=4,
|
@@ -687,6 +689,13 @@ class InSilicoPerturber:
|
|
687 |
max_ncells : None, int
|
688 |
Maximum number of cells to test.
|
689 |
If None, will test all cells.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
690 |
emb_layer : {-1, 0}
|
691 |
Embedding layer to use for quantification.
|
692 |
-1: 2nd to last layer (recommended for pretrained Geneformer)
|
@@ -723,6 +732,7 @@ class InSilicoPerturber:
|
|
723 |
self.filter_data = filter_data
|
724 |
self.cell_states_to_model = cell_states_to_model
|
725 |
self.max_ncells = max_ncells
|
|
|
726 |
self.emb_layer = emb_layer
|
727 |
self.forward_batch_size = forward_batch_size
|
728 |
self.nproc = nproc
|
@@ -886,7 +896,7 @@ class InSilicoPerturber:
|
|
886 |
if self.perturb_type in ["inhibit","activate"]:
|
887 |
if self.perturb_rank_shift is None:
|
888 |
logger.error(
|
889 |
-
"If
|
890 |
"quartile to shift by must be specified.")
|
891 |
raise
|
892 |
|
@@ -897,6 +907,18 @@ class InSilicoPerturber:
|
|
897 |
logger.warning(
|
898 |
"Values in filter_data dict must be lists. " \
|
899 |
f"Changing {key} value to list ([{value}]).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
900 |
|
901 |
def perturb_data(self,
|
902 |
model_directory,
|
@@ -995,6 +1017,15 @@ class InSilicoPerturber:
|
|
995 |
cos_sims_dict = defaultdict(list)
|
996 |
pickle_batch = -1
|
997 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
998 |
|
999 |
# make perturbation batch w/ single perturbation in multiple cells
|
1000 |
if self.perturb_group == True:
|
|
|
604 |
"filter_data": {None, dict},
|
605 |
"cell_states_to_model": {None, dict},
|
606 |
"max_ncells": {None, int},
|
607 |
+
"inds_to_perturb": {"all", dict},
|
608 |
"emb_layer": {-1, 0},
|
609 |
"forward_batch_size": {int},
|
610 |
"nproc": {int},
|
|
|
623 |
filter_data=None,
|
624 |
cell_states_to_model=None,
|
625 |
max_ncells=None,
|
626 |
+
inds_to_perturb="all",
|
627 |
emb_layer=-1,
|
628 |
forward_batch_size=100,
|
629 |
nproc=4,
|
|
|
689 |
max_ncells : None, int
|
690 |
Maximum number of cells to test.
|
691 |
If None, will test all cells.
|
692 |
+
inds_to_perturb : "all", list
|
693 |
+
Default is perturbing each cell in the dataset.
|
694 |
+
Otherwise, may provide a dict of indices of genes to perturb with keys start_ind and end_ind.
|
695 |
+
start_ind: the first index to perturb.
|
696 |
+
end_ind: the last index to perturb (exclusive).
|
697 |
+
Indices will be selected *after* the filter_data criteria and sorting.
|
698 |
+
Useful for splitting extremely large datasets across separate GPUs.
|
699 |
emb_layer : {-1, 0}
|
700 |
Embedding layer to use for quantification.
|
701 |
-1: 2nd to last layer (recommended for pretrained Geneformer)
|
|
|
732 |
self.filter_data = filter_data
|
733 |
self.cell_states_to_model = cell_states_to_model
|
734 |
self.max_ncells = max_ncells
|
735 |
+
self.inds_to_perturb = inds_to_perturb
|
736 |
self.emb_layer = emb_layer
|
737 |
self.forward_batch_size = forward_batch_size
|
738 |
self.nproc = nproc
|
|
|
896 |
if self.perturb_type in ["inhibit","activate"]:
|
897 |
if self.perturb_rank_shift is None:
|
898 |
logger.error(
|
899 |
+
"If perturb_type is inhibit or activate then " \
|
900 |
"quartile to shift by must be specified.")
|
901 |
raise
|
902 |
|
|
|
907 |
logger.warning(
|
908 |
"Values in filter_data dict must be lists. " \
|
909 |
f"Changing {key} value to list ([{value}]).")
|
910 |
+
|
911 |
+
if self.inds_to_perturb != "all":
|
912 |
+
if set(self.inds_to_perturb.keys()) != {"start", "end"}:
|
913 |
+
logger.error(
|
914 |
+
"If inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
915 |
+
)
|
916 |
+
raise
|
917 |
+
if self.inds_to_perturb["start"] < 0 or self.inds_to_perturb["end"] < 0:
|
918 |
+
logger.error(
|
919 |
+
'inds_to_perturb must be positive.'
|
920 |
+
)
|
921 |
+
raise
|
922 |
|
923 |
def perturb_data(self,
|
924 |
model_directory,
|
|
|
1017 |
cos_sims_dict = defaultdict(list)
|
1018 |
pickle_batch = -1
|
1019 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
1020 |
+
if self.inds_to_perturb != "all":
|
1021 |
+
if self.inds_to_perturb["start"] >= len(filtered_input_data):
|
1022 |
+
logger.error("inds_to_perturb['start'] is larger than the filtered dataset.")
|
1023 |
+
raise
|
1024 |
+
if self.inds_to_perturb["end"] > len(filtered_input_data):
|
1025 |
+
logger.warning("inds_to_perturb['end'] is larger than the filtered dataset. \
|
1026 |
+
Setting to the end of the filtered dataset.")
|
1027 |
+
self.inds_to_perturb["end"] = len(filtered_input_data)
|
1028 |
+
filtered_input_data = filtered_input_data.select([i for i in range(self.inds_to_perturb["start"], self.inds_to_perturb["end"])])
|
1029 |
|
1030 |
# make perturbation batch w/ single perturbation in multiple cells
|
1031 |
if self.perturb_group == True:
|