import torch import torch.nn as nn import torchvision from monotonenorm import GroupSort, direct_norm # GroupSort Neural Networks created using monotonenorm package. #See https://github.com/niklasnolte/MonotoneNorm for more information class NN_phi(nn.Module): def __init__(self,dim_input): super(NN_phi, self).__init__() self.linear1=direct_norm(torch.nn.Linear(dim_input,16),kind="two-inf") self.group1=GroupSort(16//2)#GroupSort with a grouping size of 2 self.linear2=direct_norm(torch.nn.Linear(16,32),kind="inf") self.group2=GroupSort(32//2) self.linear3=direct_norm(torch.nn.Linear(32,1),kind="inf") def forward(self, x): x=self.linear1(x) x=self.group1(x) x=self.linear2(x) x=self.group2(x) x=self.linear3(x) return x class NN_h_RELU(nn.Module): def __init__(self,dim_input): super(NN_h_RELU, self).__init__() self.linear1=torch.nn.Linear(dim_input,16) self.RELU=torch.nn.ReLU() self.linear2=torch.nn.Linear(16,32) self.linear3=torch.nn.Linear(32,1) def forward(self, x): x=self.linear1(x) x=self.RELU(x) x=self.linear2(x) x=self.RELU(x) x=self.linear3(x) return x