hchen725 committed on
Commit
b8b87fd
1 Parent(s): 1e8d481

Update geneformer/tokenizer.py


- Add checks for CLS and EOS tokens when special_token = True (see the usage sketch below)
- More efficient filtering of gene_mapping_dict for values present in gene_token_dict
- Remove summing of genes that do not exist in gene_token_dict for loom files
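
For orientation, a minimal usage sketch of the tokenizer these changes affect. The special_token and collapse_gene_ids arguments appear in the diff below; every other name here (custom_attr_name_dict, nproc, tokenize_data and its arguments) is an assumption based on the Geneformer README and may differ at this revision:

from geneformer import TranscriptomeTokenizer

# Hypothetical invocation; only special_token and collapse_gene_ids are
# confirmed by this commit, the rest is assumed from upstream docs.
tk = TranscriptomeTokenizer(
    custom_attr_name_dict={"cell_type": "cell_type"},  # loom attrs to carry over
    nproc=4,
    special_token=True,       # exercises the new <cls>/<eos> check
    collapse_gene_ids=True,   # exercises the gene_mapping_dict filtering
)
tk.tokenize_data("loom_data/", "tokenized/", "my_dataset", file_format="loom")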

Files changed (1)
  1. geneformer/tokenizer.py +14 -12
geneformer/tokenizer.py CHANGED
@@ -94,18 +94,17 @@ def sum_ensembl_ids(data_directory,
 
             if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
                 return data_directory
-
             else:
                 dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
-                data.ra["ensembl_id"] = [gene_mapping_dict.get(gene_id, gene_id) for gene_id in data.ra.ensembl_id]
-                dup_genes = [idx for idx, count in Counter(data.ra["ensembl_id"]).items() if count > 1]
+                data.ra["gene_ids_collapsed"] = gene_ids_collapsed
+                dup_genes = [idx for idx, count in Counter(data.ra["gene_ids_collapsed"]).items() if count > 1]
                 num_chunks = int(np.ceil(data.shape[1] / chunk_size))
                 first_chunk = True
                 for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
                     def process_chunk(view, duplic_genes):
-                        data_count_view = pd.DataFrame(view, index=data.ra["ensembl_id"])
+                        data_count_view = pd.DataFrame(view, index=data.ra["gene_ids_collapsed"])
                         unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
-                        dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
+                        dup_data_df = data_count_view.loc[data_count_view.index.isin([i for i in duplic_genes if "None" not in i])]
                         summed_data = dup_data_df.groupby(dup_data_df.index).sum()
                         if not summed_data.index.is_unique:
                             raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
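
A toy, self-contained illustration of the dedup path above (invented data, not the repository's code): duplicated collapsed IDs are summed via a pandas groupby, and IDs containing "None", i.e. genes with no mapping into gene_token_dict, are now excluded from the summing step:

from collections import Counter

import pandas as pd

# genes x cells counts: one unique gene, one duplicated mapped gene,
# and two unmapped rows whose collapsed ID is the string "None"
gene_ids_collapsed = ["ENSG1", "ENSG2", "ENSG2", "None", "None"]
counts = pd.DataFrame(
    [[1, 0], [2, 3], [4, 5], [6, 7], [8, 9]], index=gene_ids_collapsed
)

dup_genes = [g for g, n in Counter(gene_ids_collapsed).items() if n > 1]
unique_df = counts.loc[~counts.index.isin(dup_genes)]  # ENSG1 only
dup_df = counts.loc[counts.index.isin([g for g in dup_genes if "None" not in g])]
summed = dup_df.groupby(dup_df.index).sum()  # ENSG2 -> [6, 8]
processed = pd.concat([unique_df, summed])  # the "None" rows never reach this

Before this change the two "None" rows would have been grouped and summed into a single spurious row; in this toy they now drop out of both frames and never reach the output.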
@@ -117,12 +116,6 @@ def sum_ensembl_ids(data_directory,
                     processed_array = processed_chunk.to_numpy()
                     new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
 
-                    ra_keys = [k for k in data.ra.keys() if k != "ensembl_id"]
-                    for ra_value in ra_keys:
-                        mapping_dict = dict(zip(data.ra["ensembl_id"], data.ra[ra_value]))
-                        values_new = [mapping_dict[i] for i in processed_chunk.index]
-                        new_row_attrs[ra_value] = np.array(values_new)
-
                     if "n_counts" not in view.ca.keys():
                         total_count_view = np.sum(view[:,:], axis=0).astype(int)
                         view.ca["n_counts"] = total_count_view
@@ -263,6 +256,14 @@ class TranscriptomeTokenizer:
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
+        # check for special token in gene_token_dict
+        if self.special_token:
+            if ("<cls>" not in self.gene_token_dict.keys()) and ("<eos>" not in self.gene_token_dict.keys()):
+                logger.error(
+                    "<cls> and <eos> required in gene_token_dict when special_token = True."
+                )
+                raise
+
         # if collapsing duplicate gene IDs
         self.collapse_gene_ids = collapse_gene_ids
 
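
A standalone sketch of the same guard (a hypothetical helper, not code from the repo). Note the committed condition uses `and`, so it only fires when both tokens are missing, and it ends in a bare raise; this variant checks each token individually and raises an explicit exception instead:

def check_special_tokens(gene_token_dict, special_token=True):
    # Hypothetical variant of the guard added above: require both
    # <cls> and <eos> before tokenizing with special tokens.
    if special_token:
        missing = [t for t in ("<cls>", "<eos>") if t not in gene_token_dict]
        if missing:
            raise ValueError(
                f"{' and '.join(missing)} required in gene_token_dict "
                "when special_token = True."
            )

check_special_tokens({"<cls>": 0, "<eos>": 1})  # passes silently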
@@ -277,7 +278,8 @@
         self.gene_keys = list(self.gene_token_dict.keys())
 
         # Filter gene mapping dict for items that exist in gene_token_dict
-        self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in self.gene_keys}
+        gene_keys_set = set(self.gene_token_dict.keys())
+        self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set}
 
         # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
         self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
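
The filtering change replaces a list membership test with a set lookup: with n mappings and m token keys, `v in self.gene_keys` scans the whole list (O(m) per test, O(n·m) overall), while a prebuilt set averages O(1) per test. A sketch with invented sizes:

# Invented sizes, for illustration only.
gene_token_dict = {f"ENSG{i:011d}": i for i in range(25_000)}
gene_mapping_dict = {f"OLD{i}": f"ENSG{i % 30_000:011d}" for i in range(40_000)}

gene_keys_set = set(gene_token_dict)  # one-time O(m) build
gene_mapping_dict = {
    # mappings pointing at IDs absent from the token dict are dropped
    k: v for k, v in gene_mapping_dict.items() if v in gene_keys_set
}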
 