|
from torch import nn |
|
from transformers import PreTrainedModel, AutoModel, AutoConfig |
|
|
|
from .rna_torsionbert_config import RNATorsionBertConfig |
|
|
|
|
|
class RNATorsionBERTModel(PreTrainedModel):
    """DNABERT-backed regression model for RNA torsion prediction.

    A pretrained ``zhihan1996/DNA_bert_{k}`` encoder produces per-token
    hidden states, which a small MLP head (LayerNorm -> Linear -> GELU ->
    Linear) maps to ``config.num_classes`` outputs, squashed into [-1, 1]
    by a final Tanh.
    """

    config_class = RNATorsionBertConfig

    # Pinned Hub revisions for each supported k-mer size, so downloads are
    # reproducible even if the upstream checkpoints are updated.
    _DNABERT_REVISIONS = {3: "ed28178", 4: "c8499f0", 5: "c296157", 6: "a79a8fd"}

    def __init__(self, config):
        super().__init__(config)
        # Resolves self.model_name / self.dnabert_config from config.k
        # (raises ValueError for an unsupported k-mer size).
        self.init_model(config.k)
        self.dnabert = AutoModel.from_pretrained(
            self.model_name, config=self.dnabert_config, trust_remote_code=True
        )
        # Regression head applied to every token's hidden state.
        self.regressor = nn.Sequential(
            nn.LayerNorm(self.dnabert_config.hidden_size),
            nn.Linear(self.dnabert_config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )
        # Bounds predictions to [-1, 1].
        self.activation = nn.Tanh()

    def init_model(self, k: int) -> None:
        """Resolve and cache the DNABERT backbone name and config for ``k``.

        Sets ``self.model_name`` and ``self.dnabert_config`` as a side effect.

        Args:
            k: k-mer size of the DNABERT checkpoint to load.

        Raises:
            ValueError: if ``k`` is not a supported k-mer size (3, 4, 5, 6).
                Previously an unsupported ``k`` surfaced as an opaque
                ``KeyError`` from the revision lookup.
        """
        if k not in self._DNABERT_REVISIONS:
            supported = sorted(self._DNABERT_REVISIONS)
            raise ValueError(
                f"Unsupported k-mer size k={k!r}; expected one of {supported}."
            )
        model_name = f"zhihan1996/DNA_bert_{k}"
        self.dnabert_config = AutoConfig.from_pretrained(
            model_name,
            revision=self._DNABERT_REVISIONS[k],
            trust_remote_code=True,
        )
        self.model_name = model_name

    def forward(self, tensor):
        """Encode tokenized input and regress per-token outputs.

        Args:
            tensor: dict of tokenizer outputs (presumably ``input_ids``,
                ``attention_mask`` etc. — whatever the DNABERT encoder
                accepts), unpacked directly into the encoder call.

        Returns:
            dict with key ``"logits"``: Tanh-bounded predictions of shape
            (batch, seq_len, ``config.num_classes``).
        """
        z = self.dnabert(**tensor).last_hidden_state
        output = self.regressor(z)
        output = self.activation(output)
        return {"logits": output}