---
license: cc-by-nc-sa-4.0
widget:
- text: ACCTGA<mask>TTCTGAGTC
datasets:
- InstaDeepAI/plant-genomic-benchmark
tags:
- biology
- genomics
- language model
- plants
---
## Model Overview
AgroNT is a DNA language model trained primarily on edible plant genomes. More specifically, AgroNT uses the transformer architecture with self-attention and a masked language modeling
objective to leverage widely available genotype data from 48 different plant species and learn general representations of nucleotide sequences. AgroNT contains 1 billion parameters and has a context window of 1024 tokens.
AgroNT uses a non-overlapping 6-mer tokenizer to convert genomic nucleotide sequences to tokens. As a result, the 1024 tokens correspond to approximately 6144 base pairs.


## How to use
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch


model_name = 'agro-nucleotide-transformer-1b'

# fetch model and tokenizer from InstaDeep's hf repo
agro_nt_model = AutoModelForMaskedLM.from_pretrained(f'InstaDeepAI/{model_name}')
agro_nt_tokenizer = AutoTokenizer.from_pretrained(f'InstaDeepAI/{model_name}')

print(f"Loaded the {model_name} model with {agro_nt_model.num_parameters()} parameters and corresponding tokenizer.")

# example sequence and tokenization
sequences = ['ATATACGGCCGNC','GGGTATCGCTTCCGAC']

batch_tokens = agro_nt_tokenizer(sequences, padding="longest")['input_ids']
print(f"Tokenized sequences: {agro_nt_tokenizer.batch_decode(batch_tokens)}")

torch_batch_tokens = torch.tensor(batch_tokens)
attention_mask = torch_batch_tokens != agro_nt_tokenizer.pad_token_id

# inference
outs = agro_nt_model(
    torch_batch_tokens,
    attention_mask=attention_mask,
    encoder_attention_mask=attention_mask,
    output_hidden_states=True
)

# get the final layer embeddings and language model head logits
embeddings = outs['hidden_states'][-1].detach().numpy()
logits = outs['logits'].detach().numpy()
```
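If you need a single fixed-size embedding per sequence (for example, as features for a downstream classifier), one common convention is to mean-pool the final-layer token embeddings over non-padding positions. The snippet below is a minimal sketch that reuses the `embeddings` and `attention_mask` objects from the block above; the pooling strategy is a general practice, not something prescribed by this model card, and it keeps the [CLS] position in the average for simplicity.

```python
# Mean-pool final-layer embeddings over real (non-padding) tokens.
# embeddings: (batch, seq_len, hidden_dim) numpy array; attention_mask: (batch, seq_len) bool tensor
mask = attention_mask.numpy()[:, :, None]      # (batch, seq_len, 1), True for real tokens
summed = (embeddings * mask).sum(axis=1)       # (batch, hidden_dim)
counts = mask.sum(axis=1)                      # (batch, 1), number of real tokens per sequence
mean_embeddings = summed / counts

print(f"Per-sequence embedding shape: {mean_embeddings.shape}")
```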


## Pre-training

#### Data
Our pre-training dataset was built from the reference genomes of (mostly) edible plants contained in the Ensembl Plants database.
The dataset consists of approximately 10.5 million genomic sequences across 48 different species.

#### Processing
All reference genomes for each species were assembled into a single FASTA file. In this FASTA file, all nucleotides other than A, T, C, and G were replaced by N. A tokenizer was used to convert strings of letters into sequences of tokens.
The tokenizer's alphabet consisted of the 4<sup>6</sup> = 4096 possible 6-mer combinations of A, T, C, and G, as well as five additional tokens
representing standalone A, T, C, G, and N. It also included three special tokens: the pad [PAD], mask [MASK], and class [CLS] tokens. This resulted in a vocabulary of 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and
then converted the sequence from left to right, matching 6-mer tokens when possible, or using the standalone tokens when necessary (for instance, when the letter
N was present or when the remaining sequence length was not a multiple of 6).

 **Tokenization example**
 
 nucleotide sequence:  ```ATCCCGGNNTCGACACN```\
 tokens:  ```<CLS> <ATCCCG> <G> <N> <N> <TCGACA> <C> <N>```
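For intuition, the example above can be reproduced with a simplified sketch of this greedy left-to-right scheme. This is an illustrative re-implementation only, not the tokenizer shipped with the model, and the token spelling (`<CLS>`, `<ATCCCG>`, ...) just follows the example above.

```python
def kmer_tokenize(sequence, k=6):
    """Greedy left-to-right 6-mer tokenization sketch: emit a k-mer when the next
    k characters are all A/T/C/G, otherwise fall back to a standalone token."""
    tokens = ["<CLS>"]
    i = 0
    while i < len(sequence):
        chunk = sequence[i:i + k]
        if len(chunk) == k and all(c in "ATCG" for c in chunk):
            tokens.append(f"<{chunk}>")
            i += k
        else:
            tokens.append(f"<{sequence[i]}>")
            i += 1
    return tokens

print(kmer_tokenize("ATCCCGGNNTCGACACN"))
# ['<CLS>', '<ATCCCG>', '<G>', '<N>', '<N>', '<TCGACA>', '<C>', '<N>']
```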

#### Training
The MLM objective was used to pre-train AgroNT in a self-supervised manner. In this setting, annotations (supervision) for each sequence
are not needed: we mask some proportion of the sequence and use the information contained in the unmasked portion to predict the masked locations.
This allows us to leverage the vast amount of unlabeled genomic sequencing data available. Specifically, 15% of the tokens in the input sequence are selected for corruption:
80% of these are replaced with the mask token, 10% are randomly replaced by another token from the vocabulary, and the final 10% are left unchanged.
The tokenized sequence is passed through the model and a cross-entropy loss is computed on the masked tokens. Pre-training was carried out with a sequence length of 1024 tokens
and an effective batch size of 1.5M tokens for 315k update steps, resulting in the model training on a total of 472.5B tokens.
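As a rough illustration of this 80/10/10 corruption scheme (not the actual training code, which is not included in this repository), a BERT-style masking step over a batch of token ids can be sketched as follows. The function name `mlm_corrupt` and its parameters are hypothetical, and special tokens are ignored for simplicity.

```python
import torch

def mlm_corrupt(token_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    """Illustrative MLM corruption: select 15% of positions, replace 80% of them
    with [MASK], 10% with a random token, and leave 10% unchanged.
    (A real implementation would also exclude special tokens such as [CLS] and [PAD].)"""
    labels = token_ids.clone()
    selected = torch.rand(token_ids.shape) < mlm_prob
    labels[~selected] = -100                                   # loss only on selected positions

    corrupted = token_ids.clone()
    rand = torch.rand(token_ids.shape)
    replace_mask = selected & (rand < 0.8)                     # 80% -> [MASK]
    replace_random = selected & (rand >= 0.8) & (rand < 0.9)   # 10% -> random vocabulary token
    corrupted[replace_mask] = mask_token_id
    corrupted[replace_random] = torch.randint(vocab_size, (int(replace_random.sum()),))
    # the remaining 10% of selected positions keep their original token
    return corrupted, labels
```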

#### Hardware
Model pre-training was carried out using Google TPU v4 accelerators, specifically a TPU v4-1024 containing 512 devices. We trained for a total of approximately four days.

### BibTeX entry and citation info
```bibtex
@article{mendoza2023foundational,
  title={A Foundational Large Language Model for Edible Plant Genomes},
  author={Mendoza-Revilla, Javier and Trop, Evan and Gonzalez, Liam and Roller, Masa and Dalla-Torre, Hugo and de Almeida, Bernardo P and Richard, Guillaume and Caton, Jonathan and Lopez Carranza, Nicolas and Skwark, Marcin and others},
  journal={bioRxiv},
  pages={2023--10},
  year={2023},
  publisher={Cold Spring Harbor Laboratory}
}
```