|
--- |
|
license: cc-by-nc-sa-4.0 |
|
widget: |
|
- text: ACCTGA<mask>TTCTGAGTC |
|
datasets: |
|
- InstaDeepAI/plant-genomic-benchmark |
|
tags: |
|
- biology |
|
- genomics |
|
- language model |
|
- plants |
|
--- |
|
## Model Overview |
|
AgroNT is a DNA language model trained primarily on edible plant genomes. More specifically, AgroNT uses a transformer architecture with self-attention and a masked language modeling objective to leverage widely available genotype data from 48 different plant species and learn general representations of nucleotide sequences. AgroNT contains 1 billion parameters and has a context window of 1024 tokens.

AgroNT uses a non-overlapping 6-mer tokenizer to convert genomic nucleotide sequences into tokens. As a result, the 1024 tokens correspond to approximately 6144 base pairs.
|
|
|
|
|
## How to use |
|
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch


model_name = 'agro-nucleotide-transformer-1b'

# fetch model and tokenizer from InstaDeep's hf repo
agro_nt_model = AutoModelForMaskedLM.from_pretrained(f'InstaDeepAI/{model_name}')
agro_nt_tokenizer = AutoTokenizer.from_pretrained(f'InstaDeepAI/{model_name}')

print(f"Loaded the {model_name} model with {agro_nt_model.num_parameters()} parameters and corresponding tokenizer.")

# example sequences and tokenization
sequences = ['ATATACGGCCGNC', 'GGGTATCGCTTCCGAC']

batch_tokens = agro_nt_tokenizer(sequences, padding="longest")['input_ids']
print(f"Tokenized sequences: {agro_nt_tokenizer.batch_decode(batch_tokens)}")

# convert to a tensor and build an attention mask that ignores padding tokens
torch_batch_tokens = torch.tensor(batch_tokens)
attention_mask = torch_batch_tokens != agro_nt_tokenizer.pad_token_id

# inference
outs = agro_nt_model(
    torch_batch_tokens,
    attention_mask=attention_mask,
    encoder_attention_mask=attention_mask,
    output_hidden_states=True,
)

# get the final layer embeddings and language model head logits
embeddings = outs['hidden_states'][-1].detach().numpy()
logits = outs['logits'].detach().numpy()
```
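Building on the snippet above, the language model head can also be used to score alternatives at a masked position, mirroring the widget example at the top of this card. The sketch below reuses `agro_nt_model` and `agro_nt_tokenizer` from the previous block; the sequence and the choice of which token to mask are purely illustrative.

```python
# continuing from the snippet above: mask one 6-mer token and inspect
# the model's top predictions at that position (illustrative sketch)
sequence = 'ATATACGGCCGTCGGGTATCGCTTCCGAC'

input_ids = agro_nt_tokenizer(sequence, return_tensors="pt")['input_ids']

# mask the second token, i.e. the first 6-mer after the <CLS> token
masked_ids = input_ids.clone()
masked_ids[0, 1] = agro_nt_tokenizer.mask_token_id

attention_mask = masked_ids != agro_nt_tokenizer.pad_token_id
with torch.no_grad():
    outs = agro_nt_model(
        masked_ids,
        attention_mask=attention_mask,
        encoder_attention_mask=attention_mask,
    )

# top-5 candidate tokens for the masked position
top5_ids = outs['logits'][0, 1].topk(5).indices.tolist()
print(agro_nt_tokenizer.convert_ids_to_tokens(top5_ids))
```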
|
|
|
|
|
## Pre-training |
|
|
|
#### Data |
|
Our pre-training dataset was built from the reference genomes of (mostly) edible plants contained in the Ensembl Plants database. The dataset consists of approximately 10.5 million genomic sequences across 48 different species.
|
|
|
#### Processing |
|
All reference genomes for each species were assembled into a single FASTA file. In this FASTA file, all nucleotides other than A, T, C, G were replaced by N. A tokenizer was used to convert strings of letters into sequences of tokens. The tokenizer's alphabet consisted of the 4<sup>6</sup> = 4096 possible 6-mer combinations obtained by combining A, T, C, G, as well as five additional tokens representing standalone A, T, C, G, and N. It also included three special tokens: the pad [PAD], mask [MASK], and class [CLS] tokens. This resulted in a vocabulary of 4104 tokens. To tokenize an input sequence, the tokenizer started with a class token and then converted the sequence from left to right, matching 6-mer tokens when possible, or using the standalone tokens when necessary (for instance, when the letter N was present or if the sequence length was not a multiple of 6).
|
|
|
**Tokenization example** |
|
|
|
nucleotide sequence: ```ATCCCGGNNTCGACACN```\
tokens: ```<CLS> <ATCCCG> <G> <N> <N> <TCGACA> <C> <N>```
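The greedy left-to-right scheme described above is simple enough to sketch directly. The function below is a hypothetical re-implementation for illustration only (the model's actual tokenizer should always be loaded with `AutoTokenizer` as shown earlier); it emits token strings in the angle-bracket notation used in the example above rather than token ids.

```python
def kmer_tokenize(sequence, k=6):
    """Greedy left-to-right k-mer tokenization, as described above.

    Emits a <CLS> token, then k-mers made only of A/T/C/G, falling back to
    standalone tokens when an N is encountered or fewer than k bases remain.
    Illustrative sketch only, not the model's actual tokenizer.
    """
    tokens = ['<CLS>']
    i = 0
    while i < len(sequence):
        chunk = sequence[i:i + k]
        if len(chunk) == k and set(chunk) <= set('ATCG'):
            tokens.append(f'<{chunk}>')
            i += k
        else:
            tokens.append(f'<{sequence[i]}>')
            i += 1
    return tokens

print(' '.join(kmer_tokenize('ATCCCGGNNTCGACACN')))
# <CLS> <ATCCCG> <G> <N> <N> <TCGACA> <C> <N>
```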
|
|
|
#### Training |
|
The MLM objective was used to pre-train AgroNT in a self-supervised manner. In a self-supervised learning setting, annotations (supervision) for each sequence are not needed, since we can mask some proportion of the sequence and use the information contained in the unmasked portion to predict the masked locations. This allows us to leverage the vast amount of unlabeled genomic sequencing data available. Specifically, 15% of the tokens in the input sequence are selected for corruption: 80% of these are replaced with a mask token, 10% are randomly replaced by another token from the vocabulary, and the final 10% keep their original token. The tokenized sequence is passed through the model and a cross-entropy loss is computed over the masked tokens. Pre-training was carried out with a sequence length of 1024 tokens and an effective batch size of 1.5M tokens for 315k update steps, resulting in the model training on a total of 472.5B tokens.
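For intuition, the 80/10/10 corruption scheme is the standard BERT-style masking recipe. The following is a rough sketch of that recipe (a hypothetical `mlm_corrupt` helper, not the actual pre-training code); `mask_token_id` and `vocab_size` would come from the tokenizer.

```python
import torch

def mlm_corrupt(input_ids, mask_token_id, vocab_size, mlm_prob=0.15):
    """BERT-style masking sketch: select ~15% of tokens as prediction targets,
    replace 80% of them with the mask token, 10% with a random token, and
    leave the final 10% unchanged. Labels are -100 where no loss is computed."""
    labels = input_ids.clone()
    target = torch.rand(input_ids.shape) < mlm_prob   # positions to predict
    labels[~target] = -100                            # ignored by cross entropy

    corrupted = input_ids.clone()
    # 80% of targets -> mask token
    masked = target & (torch.rand(input_ids.shape) < 0.8)
    corrupted[masked] = mask_token_id
    # half of the remaining 20% of targets -> random vocabulary token
    randomized = target & ~masked & (torch.rand(input_ids.shape) < 0.5)
    corrupted[randomized] = torch.randint(vocab_size, input_ids.shape)[randomized]
    # the final 10% of targets keep their original token
    return corrupted, labels

# e.g. corrupted, labels = mlm_corrupt(torch_batch_tokens,
#                                      agro_nt_tokenizer.mask_token_id,
#                                      len(agro_nt_tokenizer))
```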
|
|
|
#### Hardware |
|
Model pre-training was carried out using Google TPU v4 accelerators, specifically a TPU v4-1024 containing 512 devices. Training took approximately four days in total.
|
|
|
### BibTeX entry and citation info |
|
```bibtex
@article{mendoza2023foundational,
  title={A Foundational Large Language Model for Edible Plant Genomes},
  author={Mendoza-Revilla, Javier and Trop, Evan and Gonzalez, Liam and Roller, Masa and Dalla-Torre, Hugo and de Almeida, Bernardo P and Richard, Guillaume and Caton, Jonathan and Lopez Carranza, Nicolas and Skwark, Marcin and others},
  journal={bioRxiv},
  pages={2023--10},
  year={2023},
  publisher={Cold Spring Harbor Laboratory}
}
```