zhihan1996 commited on
Commit
b054dd3
1 Parent(s): a794b19

Update dnabert_layer.py

Browse files
Files changed (1) hide show
  1. dnabert_layer.py +6 -0
dnabert_layer.py CHANGED
@@ -5,6 +5,7 @@ import torch.nn as nn
5
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
  from transformers.models.bert.modeling_bert import BertModel as TransformersBertModel
7
  from transformers.models.bert.modeling_bert import BertForMaskedLM as TransformersBertForMaskedLM
 
8
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
9
  from transformers.modeling_outputs import SequenceClassifierOutput
10
 
@@ -17,6 +18,11 @@ class BertForMaskedLM(TransformersBertForMaskedLM):
17
  def __init__(self, config):
18
  super().__init__(config)
19
 
 
 
 
 
 
20
 
21
  class DNABertForSequenceClassification(BertPreTrainedModel):
22
  def __init__(self, config):
 
5
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
  from transformers.models.bert.modeling_bert import BertModel as TransformersBertModel
7
  from transformers.models.bert.modeling_bert import BertForMaskedLM as TransformersBertForMaskedLM
8
+ from transformers.models.bert.modeling_bert import BertForPreTraining as TransformersBertForPreTraining
9
  from transformers.models.bert.modeling_bert import BertPreTrainedModel
10
  from transformers.modeling_outputs import SequenceClassifierOutput
11
 
 
18
  def __init__(self, config):
19
  super().__init__(config)
20
 
21
+ class BertForPreTraining(TransformersBertForPreTraining):
22
+ def __init__(self, config):
23
+ super().__init__(config)
24
+
25
+
26
 
27
  class DNABertForSequenceClassification(BertPreTrainedModel):
28
  def __init__(self, config):