Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel | |
# Module-wide value passed as ``bias=`` to the convolution layers defined in
# this file (SimpleCNN, ResNet1dBlock). MLP takes its own ``biases`` argument
# instead of reading this flag.
biases: bool = False
class Pool2BN(nn.Module):
    """Global average + max pooling over the last axis, then BatchNorm.

    The input must be rank 3, i.e. (batch, channels, length), because the
    pooled maps are reduced over their trailing size-1 axis and concatenated
    along axis 1. The result has shape (batch, 2 * num_channels): the
    per-channel means first, then the per-channel maxima. Both halves are
    normalised jointly by a single ``BatchNorm1d`` over ``2 * num_channels``
    features.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_channels * 2)

    def forward(self, x):
        # adaptive pooling to size 1 leaves a trailing axis of length 1; drop it
        mean_part = torch.nn.functional.adaptive_avg_pool1d(x, 1).squeeze(-1)
        max_part = torch.nn.functional.adaptive_max_pool1d(x, 1).squeeze(-1)
        stacked = torch.cat((mean_part, max_part), dim=1)
        return self.bn(stacked)
class MLP(torch.nn.Module):
    """Stack of fully connected layers described by a list of widths.

    ``layer_sizes[0]`` is the input width. Every later entry is the output
    width of one ``Linear`` layer. Each ``Linear`` except the last is followed
    by an activation and then ``BatchNorm1d``. When ``dropout`` is not None,
    a ``Dropout`` is inserted in front of every ``Linear`` except the first.

    Args:
        layer_sizes: input width followed by each layer's output width.
        biases: forwarded as ``bias`` to every ``Linear``.
        sigmoid: when true, use ``Tanh`` (despite the flag's name) instead of
            ``ReLU`` as the hidden activation.
        dropout: dropout probability, or ``None`` to add no dropout layers.
    """

    def __init__(self, layer_sizes, biases=False, sigmoid=False, dropout=None):
        super().__init__()
        hidden_act = torch.nn.Tanh if sigmoid else torch.nn.ReLU
        final_idx = len(layer_sizes) - 2  # position of the output Linear
        fan_in = layer_sizes[0]
        modules = []
        for idx, fan_out in enumerate(layer_sizes[1:]):
            if idx > 0 and dropout is not None:
                modules.append(torch.nn.Dropout(dropout))
            modules.append(
                torch.nn.Linear(in_features=fan_in, out_features=fan_out, bias=biases)
            )
            # the output layer is left raw: no activation, no normalisation
            if idx < final_idx:
                modules += [hidden_act(), torch.nn.BatchNorm1d(fan_out)]
            fan_in = fan_out
        self.mlp = torch.nn.Sequential(*modules)

    def forward(self, x):
        out = self.mlp(x)
        return out
class SimpleCNN(torch.nn.Module):
    """Small convolutional feature extractor followed by global pooling.

    Pipeline: ``Conv1d`` (4 input channels, ``num_filters`` outputs, kernel
    ``k``, no padding) -> activation. When ``additional_layer`` is true, a
    second stage follows: ``BatchNorm1d`` -> 1x1 ``Conv1d`` -> activation.
    ``Pool2BN`` then reduces the result to a (batch, 2 * num_filters) tensor.
    The activation is ``tanh`` when ``sigmoid`` is true, otherwise ``relu``.
    Convolution biases follow the module-level ``biases`` flag.
    """

    def __init__(self, k, num_filters, sigmoid=False, additional_layer=False):
        super().__init__()
        self.sigmoid = sigmoid
        # 4 input channels — presumably one-hot nucleotides; confirm with callers
        self.cnn = torch.nn.Conv1d(
            in_channels=4, out_channels=num_filters, kernel_size=k, bias=biases
        )
        self.additional_layer = additional_layer
        if self.additional_layer:
            self.bn = nn.BatchNorm1d(num_filters)
            self.cnn2 = nn.Conv1d(
                in_channels=num_filters, out_channels=num_filters, kernel_size=1, bias=biases
            )
        self.post = Pool2BN(num_filters)

    def _act(self, t):
        # Activation shared by both convolution stages.
        return torch.tanh(t) if self.sigmoid else torch.relu(t)

    def forward(self, x):
        h = self._act(self.cnn(x))
        if self.additional_layer:
            h = self._act(self.cnn2(self.bn(h)))
        return self.post(h)
class ResNet1dBlock(torch.nn.Module):
    """Pre-activation residual block over 1-D feature maps.

    Computes ``x + cnn2(relu(bn2(cnn1(drop(relu(bn1(x)))))))``, where
    ``drop`` is present only when ``dropout`` is not None.

    ``cnn1`` maps ``num_filters`` -> ``internal_filters`` channels (kernel
    ``k1``, dilation ``dilation``, default 1). ``cnn2`` maps back to
    ``num_filters`` (kernel ``k2``). Both convolutions use ``padding="same"``
    (stride 1), so the output length equals the input length and the residual
    sum is valid for even as well as odd kernel sizes. For odd kernels this is
    the same symmetric ``(k // 2) * dilation`` padding as before. Convolution
    biases follow the module-level ``biases`` flag.
    """

    def __init__(self, num_filters, k1, internal_filters, k2, dropout=None, dilation=None):
        super().__init__()
        self.init_do = torch.nn.Dropout(dropout) if dropout is not None else None
        self.bn1 = torch.nn.BatchNorm1d(num_filters)
        if dilation is None:
            dilation = 1
        # "same" padding keeps the length unchanged; the previous
        # (k // 2) * dilation padding grew the sequence by one for even k,
        # which broke the residual addition in forward().
        self.cnn1 = torch.nn.Conv1d(
            in_channels=num_filters,
            out_channels=internal_filters,
            kernel_size=k1,
            bias=biases,
            dilation=dilation,
            padding="same",
        )
        self.bn2 = torch.nn.BatchNorm1d(internal_filters)
        self.cnn2 = torch.nn.Conv1d(
            in_channels=internal_filters,
            out_channels=num_filters,
            kernel_size=k2,
            bias=biases,
            padding="same",
        )

    def forward(self, x):
        x_orig = x
        x = self.bn1(x)
        x = torch.relu(x)
        if self.init_do is not None:
            x = self.init_do(x)
        x = self.cnn1(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.cnn2(x)
        # shapes match because both convolutions preserve length and cnn2
        # restores num_filters channels
        return x + x_orig
class ResNet1d(torch.nn.Module):
    """Sequential stack of ``ResNet1dBlock`` residual blocks.

    Each entry of ``block_spec`` is unpacked positionally after
    ``num_filters``, so it is presumably a ``(k1, internal_filters, k2)``
    triple. ``dropout`` and ``dilation`` are shared by every block.
    """

    def __init__(self, num_filters, block_spec, dropout=None, dilation=None):
        super().__init__()
        self.blocks = torch.nn.Sequential(
            *(
                ResNet1dBlock(num_filters, *spec, dropout=dropout, dilation=dilation)
                for spec in block_spec
            )
        )

    def forward(self, x):
        out = self.blocks(x)
        return out
class LogisticRegressionTorch(nn.Module):
    """Batch-normalised linear layer returning raw scores (logits).

    The input is normalised by ``BatchNorm1d(input_dim)`` and projected by a
    biased ``Linear(input_dim, output_dim)``. No sigmoid/softmax is applied
    here.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Run an encoder and classify its first-position final hidden state.

    The encoder is called with ``output_hidden_states=True``. The vector at
    sequence position 0 of the last entry of ``hidden_states`` is passed to
    ``classifier``, and the classifier's output is returned unchanged.
    Position 0 is presumably the [CLS] token of a BERT-style tokenizer;
    confirm with the tokenizer in use.

    ``num_labels`` is stored on the module; ``forward`` does not read it.
    """

    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        encoder_out = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # rank-3 hidden states; assumed (batch, seq, hidden) — TODO confirm per encoder
        final_layer = encoder_out.hidden_states[-1]
        first_token = final_layer[:, 0, :]
        return self.classifier(first_token)