hchen725 committed
Commit 5197a60
1 Parent(s): 47341ab

Update geneformer/tokenizer.py

Files changed (1):
  1. geneformer/tokenizer.py +49 -37
geneformer/tokenizer.py CHANGED
@@ -1,37 +1,23 @@
  """
  Geneformer tokenizer.
-
  **Input data:**
-
  | *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
  | *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
  | *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
-
  | *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
  | *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
-
  **Usage:**
-
  .. code-block :: python
-
  >>> from geneformer import TranscriptomeTokenizer
  >>> tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ"}, nproc=4)
  >>> tk.tokenize_data("data_directory", "output_directory", "output_prefix")
-
  **Description:**
-
  | Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.
-
  | The discussion below references the .loom file format, but the analogous labels are required for .h5ad files, except that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
-
  | Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifier for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
-
  | No cell metadata is required, but custom cell attributes may be passed on to the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively, the following custom attribute dictionary should be provided: {"cell_type": "cell_type", "organ_major": "organ"}.
-
  | Additionally, if the original .loom file contains a cell column attribute called "filter_pass", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with "1" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.
-
  | If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer.
-
  """

  from __future__ import annotations
@@ -39,7 +25,6 @@ from __future__ import annotations
  import os
  import logging
  import pickle
- import sys
  import warnings
  from pathlib import Path
  from typing import Literal
@@ -50,7 +35,6 @@ import numpy as np
  import scanpy as sc
  import loompy as lp
  import pandas as pd
- import anndata as ad
  import scipy.sparse as sp
  from datasets import Dataset

@@ -82,22 +66,38 @@ def tokenize_cell(gene_vector, gene_tokens):
  return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])

  def sum_ensembl_ids(data_directory,
+ collapse_gene_ids,
  gene_mapping_dict,
+ gene_token_dict,
  file_format = "loom",
  chunk_size = 512):
+
  if file_format == "loom":
  """
  Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
  """
  with lp.connect(data_directory) as data:
  assert "ensembl_id" in data.ra.keys(), "'ensembl_id' column missing from data.ra.keys()"
+ gene_ids_in_dict = [gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()]
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
+ token_genes_unique = True
+ else:
+ token_genes_unique = False
+ if collapse_gene_ids is False:
+ if token_genes_unique:
+ return data_directory
+ else:
+ raise ValueError("Error: data Ensembl IDs non-unique.")
+
  gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id]
-
- if len(set(gene_ids_collapsed)) == len(set(data.ra.ensembl_id)):
+ gene_ids_collapsed_in_dict = [gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()]
+
+ if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
  return data_directory

  else:
  dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
+ data.ra["ensembl_id"] = [gene_mapping_dict.get(gene_id, gene_id) for gene_id in data.ra.ensembl_id]
  dup_genes = [idx for idx, count in Counter(data.ra["ensembl_id"]).items() if count > 1]
  num_chunks = int(np.ceil(data.shape[1] / chunk_size))
  first_chunk = True
@@ -108,10 +108,10 @@ def sum_ensembl_ids(data_directory,
  dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
  summed_data = dup_data_df.groupby(dup_data_df.index).sum()
  if not summed_data.index.is_unique:
- raise ValueError("Error: summed data frame non-unique.")
+ raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
  data_count_view = pd.concat([unique_data_df, summed_data], axis=0)
  if not data_count_view.index.is_unique:
- raise ValueError("Error: final data frame non-unique.")
+ raise ValueError("Error: Ensembl IDs in final data frame non-unique.")
  return data_count_view
  processed_chunk = process_chunk(view[:, :], dup_genes)
  processed_array = processed_chunk.to_numpy()
@@ -144,10 +144,20 @@ def sum_ensembl_ids(data_directory,
  data = sc.read_h5ad(str(data_directory))

  assert "ensembl_id" in data.var.columns, "'ensembl_id' column missing from data.var"
+ gene_ids_in_dict = [gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()]
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
+ token_genes_unique = True
+ else:
+ token_genes_unique = False
+ if collapse_gene_ids is False:
+ if token_genes_unique:
+ return data
+ else:
+ raise ValueError("Error: data Ensembl IDs non-unique.")

  gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id]
-
- if len(set(gene_ids_collapsed)) == len(set(data.var.ensembl_id)):
+ gene_ids_collapsed_in_dict = [gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()]
+ if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
  return data

  else:
@@ -199,15 +209,14 @@ class TranscriptomeTokenizer:
  chunk_size=512,
  model_input_size=2048,
  special_token=False,
+ collapse_gene_ids=True,
  gene_median_file=GENE_MEDIAN_FILE,
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
  gene_mapping_file=ENSEMBL_MAPPING_FILE,
  ):
  """
  Initialize tokenizer.
-
  **Parameters:**
-
  custom_attr_name_dict : None, dict
  | Dictionary of custom attributes to be added to the dataset.
  | Keys are the names of the attributes in the loom file.
@@ -220,16 +229,15 @@ class TranscriptomeTokenizer:
  | Max input size of model to truncate input to.
  special_token : bool = False
  | Adds CLS token before and EOS token after rank value encoding.
- collapse_gene_ids : bool = False
+ collapse_gene_ids : bool = True
  | Whether to collapse gene IDs based on gene mapping dictionary.
  gene_median_file : Path
  | Path to pickle file containing dictionary of non-zero median
  | gene expression values across Genecorpus-30M.
  token_dictionary_file : Path
  | Path to pickle file containing token dictionary (Ensembl IDs:token).
- gene_mapping_file : Path
+ gene_mapping_file : None, Path
  | Path to pickle file containing dictionary for collapsing gene IDs.
-
  """
  # dictionary of custom attributes {output dataset column name: input .loom column name}
  self.custom_attr_name_dict = custom_attr_name_dict
@@ -255,9 +263,15 @@ class TranscriptomeTokenizer:
  with open(token_dictionary_file, "rb") as f:
  self.gene_token_dict = pickle.load(f)

+ # if collapsing duplicate gene IDs
+ self.collapse_gene_ids = collapse_gene_ids
+
  # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
- with open(gene_mapping_file, "rb") as f:
- self.gene_mapping_dict = pickle.load(f)
+ if gene_mapping_file is not None:
+ with open(gene_mapping_file, "rb") as f:
+ self.gene_mapping_dict = pickle.load(f)
+ else:
+ self.gene_mapping_dict = {k:k for k,_ in self.gene_token_dict.items()}

  # gene keys for full vocabulary
  self.gene_keys = list(self.gene_token_dict.keys())
@@ -275,9 +289,7 @@ class TranscriptomeTokenizer:
  ):
  """
  Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
-
  **Parameters:**
-
  data_directory : Path
  | Path to directory containing loom files or anndata files
  output_directory : Path
@@ -288,7 +300,6 @@ class TranscriptomeTokenizer:
  | Format of input files. Can be "loom" or "h5ad".
  use_generator : bool
  | Whether to use generator or dict for tokenization.
-
  """
  tokenized_cells, cell_metadata = self.tokenize_files(
  Path(data_directory), file_format
@@ -339,7 +350,7 @@ class TranscriptomeTokenizer:
  return tokenized_cells, cell_metadata

  def tokenize_anndata(self, adata_file_path, target_sum=10_000):
- adata = sum_ensembl_ids(adata_file_path, self.gene_mapping_dict, file_format = "h5ad", chunk_size = self.chunk_size)
+ adata = sum_ensembl_ids(adata_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, file_format = "h5ad", chunk_size = self.chunk_size)

  if self.custom_attr_name_dict is not None:
  file_cell_metadata = {
@@ -406,7 +417,8 @@ class TranscriptomeTokenizer:
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
  }

- loom_file_path = sum_ensembl_ids(loom_file_path, self.gene_mapping_dict, file_format = "loom", chunk_size = self.chunk_size)
+ dedup_filename = loom_file_path.with_name(loom_file_path.stem + "__dedup.loom")
+ loom_file_path = sum_ensembl_ids(loom_file_path, self.collapse_gene_ids, self.gene_mapping_dict, self.gene_token_dict, file_format = "loom", chunk_size = self.chunk_size)

  with lp.connect(str(loom_file_path)) as data:
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
@@ -469,8 +481,8 @@ class TranscriptomeTokenizer:
  else:
  file_cell_metadata = None

- if "__dedup" in str(loom_file_path):
- os.remove(str(loom_file_path))
+ if str(dedup_filename) == str(loom_file_path):
+ os.remove(str(dedup_filename))

  return tokenized_cells, file_cell_metadata

@@ -527,4 +539,4 @@ class TranscriptomeTokenizer:
  output_dataset_truncated = output_dataset.map(
  format_cell_features, num_proc=self.nproc
  )
- return output_dataset_truncated
+ return output_dataset_truncated
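
For quick orientation, here is a rough usage sketch assembled from the module docstring example and the updated __init__ signature in the diff above. The directory paths are placeholders; collapse_gene_ids and the gene_mapping_file=None fallback are the behaviors this commit touches.

    from geneformer import TranscriptomeTokenizer

    # Sketch only: paths are placeholders.
    # With collapse_gene_ids=True (the new default), Ensembl IDs are first mapped through the
    # gene mapping dictionary and any resulting duplicates are summed before tokenization.
    # Passing gene_mapping_file=None instead builds an identity mapping over the token dictionary.
    tk = TranscriptomeTokenizer(
        {"cell_type": "cell_type", "organ_major": "organ"},
        nproc=4,
        collapse_gene_ids=True,
    )
    tk.tokenize_data("data_directory", "output_directory", "output_prefix", file_format="loom")

The deduplication step inside sum_ensembl_ids is ordinary pandas: rows sharing an Ensembl ID are grouped by index and summed. A self-contained toy illustration (made-up IDs and counts, not from the repository):

    import pandas as pd

    # Toy genes-by-cells count matrix in which two rows collapse to the same Ensembl ID.
    counts = pd.DataFrame(
        [[1, 0], [2, 3], [4, 5]],
        index=["ENSG_A", "ENSG_B", "ENSG_A"],
        columns=["cell_1", "cell_2"],
    )
    summed = counts.groupby(counts.index).sum()
    print(summed)  # the duplicated gene's counts are added together, leaving unique row labels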