# NOOTestspace / model_archs
# Author: mawairon — last commit "Update model_archs" (1bb2663, verified), 4.8 kB
import torch
import torch.nn as nn
# Module-level default for the ``bias`` argument of the Conv1d layers built by
# SimpleCNN and ResNet1dBlock. MLP does not read this: it takes its own
# ``biases`` parameter, which shadows this name inside MLP.__init__.
biases = False
class Pool2BN(nn.Module):
    """Global average + global max pooling over the last axis, then BatchNorm.

    The average-pooled and max-pooled features are concatenated along
    dimension 1, so the normalized output has ``2 * num_channels`` features.

    Args:
        num_channels: size of dimension 1 of the (rank-3) input.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(2 * num_channels)

    def forward(self, x):
        """Pool the last axis of ``x`` away and normalize the concatenated features."""
        # Pooling to length 1 and indexing position 0 drops the last axis.
        pooled_mean = torch.nn.functional.adaptive_avg_pool1d(x, 1)[:, :, 0]
        pooled_max = torch.nn.functional.adaptive_max_pool1d(x, 1)[:, :, 0]
        return self.bn(torch.cat((pooled_mean, pooled_max), dim=1))
class MLP(torch.nn.Module):
    """Stack of Linear layers with optional Dropout, an activation and BatchNorm.

    One Linear layer is built for each consecutive pair of sizes in
    ``layer_sizes``. Every Linear layer except the first is preceded by
    Dropout (only when ``dropout`` is not ``None``). Every Linear layer except
    the last is followed by an activation and ``BatchNorm1d`` over its output
    size.

    Args:
        layer_sizes: feature sizes, input size first and output size last.
        biases: whether the Linear layers get a bias term.
        sigmoid: if true, use Tanh as the hidden activation (despite the
            name, no Sigmoid is used); otherwise ReLU.
        dropout: dropout probability, or ``None`` for no dropout.
    """

    def __init__(self, layer_sizes, biases=False, sigmoid=False, dropout=None):
        super().__init__()
        final_idx = len(layer_sizes) - 2
        n_in = layer_sizes[0]
        modules = []
        for idx, n_out in enumerate(layer_sizes[1:]):
            is_first = idx == 0
            is_last = idx == final_idx
            if dropout is not None and not is_first:
                modules.append(torch.nn.Dropout(dropout))
            modules.append(torch.nn.Linear(in_features=n_in, out_features=n_out, bias=biases))
            if not is_last:
                activation = torch.nn.Tanh() if sigmoid else torch.nn.ReLU()
                modules.extend((activation, torch.nn.BatchNorm1d(n_out)))
            n_in = n_out
        self.mlp = torch.nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.mlp(x)
class SimpleCNN(torch.nn.Module):
    """Conv1d feature extractor followed by global average/max pooling.

    A Conv1d (4 input channels -> ``num_filters``) is applied, followed by an
    activation. When ``additional_layer`` is true, the result then passes
    through BatchNorm, a kernel-size-1 Conv1d and a second activation.
    ``Pool2BN`` finally pools the last axis, giving ``2 * num_filters``
    features per sample.

    Args:
        k: kernel size of the first convolution.
        num_filters: number of convolution output channels.
        sigmoid: if true, use tanh as the activation; otherwise ReLU.
        additional_layer: add the BatchNorm + 1x1 Conv1d stage.
    """

    def __init__(self, k, num_filters, sigmoid=False, additional_layer=False):
        super().__init__()
        self.sigmoid = sigmoid
        # NOTE(review): the fixed in_channels=4 presumably matches a one-hot
        # nucleotide encoding of the input — confirm with callers.
        self.cnn = torch.nn.Conv1d(in_channels=4, out_channels=num_filters, kernel_size=k, bias=biases)
        self.additional_layer = additional_layer
        if additional_layer:
            self.bn = nn.BatchNorm1d(num_filters)
            self.cnn2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters, kernel_size=1, bias=biases)
        self.post = Pool2BN(num_filters)

    def _activate(self, t):
        """Apply tanh when ``self.sigmoid`` is set, ReLU otherwise."""
        return torch.tanh(t) if self.sigmoid else torch.relu(t)

    def forward(self, x):
        """Convolve, activate, optionally refine, then pool ``x``."""
        h = self._activate(self.cnn(x))
        if self.additional_layer:
            h = self._activate(self.cnn2(self.bn(h)))
        return self.post(h)
class ResNet1dBlock(torch.nn.Module):
    """Pre-activation 1-D residual block.

    ``forward`` computes
    ``x + cnn2(relu(bn2(cnn1(dropout(relu(bn1(x)))))))``; the dropout step is
    skipped when ``dropout`` is ``None``. Both convolutions are padded so the
    sequence length is unchanged, which the residual addition requires.

    Args:
        num_filters: channels of the block input and output.
        k1: kernel size of the first convolution (odd or even).
        internal_filters: output channels of the first convolution.
        k2: kernel size of the second convolution (odd or even).
        dropout: dropout probability applied before ``cnn1``, or ``None``.
        dilation: dilation of the first convolution; ``None`` means 1.
    """

    def __init__(self, num_filters, k1, internal_filters, k2, dropout=None, dilation=None):
        super().__init__()
        self.init_do = torch.nn.Dropout(dropout) if dropout is not None else None
        self.bn1 = torch.nn.BatchNorm1d(num_filters)
        if dilation is None:
            dilation = 1
        self.cnn1 = torch.nn.Conv1d(
            in_channels=num_filters, out_channels=internal_filters, kernel_size=k1, bias=biases,
            dilation=dilation, padding=self._length_preserving_padding(k1, dilation),
        )
        self.bn2 = torch.nn.BatchNorm1d(internal_filters)
        self.cnn2 = torch.nn.Conv1d(
            in_channels=internal_filters, out_channels=num_filters, kernel_size=k2, bias=biases,
            padding=self._length_preserving_padding(k2, 1),
        )

    @staticmethod
    def _length_preserving_padding(kernel_size, dilation):
        """Return a Conv1d ``padding`` value that keeps the sequence length unchanged."""
        if kernel_size % 2 == 1:
            # Odd kernel: symmetric padding, identical to the original (k // 2) * dilation.
            return (kernel_size // 2) * dilation
        # Even kernel: the length-preserving padding is asymmetric. The old symmetric
        # (k // 2) * dilation grew the length by `dilation` and made `x + x_orig`
        # fail with a shape mismatch; 'same' pads one extra element on one side.
        return "same"

    def forward(self, x):
        """Apply the residual branch to ``x`` and add ``x`` back."""
        x_orig = x
        x = torch.relu(self.bn1(x))
        if self.init_do is not None:
            x = self.init_do(x)
        x = torch.relu(self.bn2(self.cnn1(x)))
        x = self.cnn2(x)
        return x + x_orig
class ResNet1d(torch.nn.Module):
    """Sequential stack of ``ResNet1dBlock`` modules that share ``num_filters``.

    Args:
        num_filters: channels of every block's input and output.
        block_spec: iterable with one entry per block; each entry is unpacked
            positionally after ``num_filters``, i.e. ``(k1, internal_filters, k2)``.
        dropout: forwarded unchanged to every block.
        dilation: forwarded unchanged to every block.
    """

    def __init__(self, num_filters, block_spec, dropout=None, dilation=None):
        super().__init__()
        self.blocks = torch.nn.Sequential(
            *(ResNet1dBlock(num_filters, *spec, dropout=dropout, dilation=dilation) for spec in block_spec)
        )

    def forward(self, x):
        """Pass ``x`` through every block in order."""
        return self.blocks(x)
class LogisticRegressionTorch(nn.Module):
    """BatchNorm over the input features followed by one Linear layer.

    No sigmoid or softmax is applied; ``forward`` returns the raw Linear
    output (logits).

    Args:
        input_dim: number of input features.
        output_dim: number of output logits.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Normalize ``x`` and project it to ``output_dim`` logits."""
        return self.linear(self.batch_norm(x))
class BertClassifier(nn.Module):
    """Encoder plus a classification head applied to the first-token hidden state.

    Args:
        bert_model: encoder module, presumably a Hugging Face ``AutoModel``
            instance. It must accept ``(input_ids, attention_mask=...,
            output_hidden_states=True)`` and return an object with a
            ``hidden_states`` sequence of rank-3 tensors.
        classifier: head applied to ``hidden_states[-1][:, 0, :]``; typically a
            ``LogisticRegressionTorch``.
        num_labels: number of labels; stored on the instance, not used by ``forward``.
    """

    # ``bert_model`` is annotated as nn.Module: the previous ``AutoModel`` annotation
    # was never imported in this file and raised NameError when the class was defined.
    def __init__(self, bert_model: nn.Module, classifier: "LogisticRegressionTorch", num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None):
        """Return classifier logits for the first token of the last hidden-state entry."""
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Take position 0 along dimension 1 of the last hidden-state tensor.
        # NOTE(review): presumably the [CLS] token for BERT-style tokenizers — confirm.
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        return self.classifier(pooled_output)