# hf_model.py
import os
import torch
from torchaudio import transforms
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils.hub import cached_file
def init_layer(layer):
    """Initialize a Linear or Convolutional layer."""
    nn.init.xavier_uniform_(layer.weight)
    if getattr(layer, "bias", None) is not None:
        layer.bias.data.fill_(0.)
def init_bn(bn):
"""Initialize a Batchnorm layer. """
bn.bias.data.fill_(0.)
bn.weight.data.fill_(1.)
def interpolate(x, ratio):
"""Interpolate data in time domain. This is used to compensate the
resolution reduction in downsampling of a CNN.
Args:
x: (batch_size, time_steps, classes_num)
ratio: int, ratio to interpolate
Returns:
upsampled: (batch_size, time_steps * ratio, classes_num)
"""
(batch_size, time_steps, classes_num) = x.shape
upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
return upsampled
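# Example (shapes only, illustrative): with x of shape (2, 250, 447) and
# ratio=4, interpolate(x, 4) returns (2, 1000, 447); each frame's scores are
# simply repeated 4 times along the time axis.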
def pad_framewise_output(framewise_output, frames_num):
    """Pad framewise_output to the same length as the input frames. The pad
    value is the value of the last frame.
    Args:
        framewise_output: (batch_size, frames_num, classes_num)
        frames_num: int, target number of frames after padding
    Returns:
        output: (batch_size, frames_num, classes_num)
    """
    # Repeat the last frame to fill the remaining time steps.
    pad = framewise_output[:, -1:, :].repeat(
        1, frames_num - framewise_output.shape[1], 1)
    output = torch.cat((framewise_output, pad), dim=1)  # (batch_size, frames_num, classes_num)
    return output
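# Example (shapes only, illustrative): padding a (2, 996, 447) framewise
# output to frames_num=1001 repeats the last frame 5 more times, giving
# (2, 1001, 447).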
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
        super().__init__()
self.conv1 = nn.Conv2d(in_channels=in_channels,
out_channels=out_channels,
kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1), bias=False)
self.conv2 = nn.Conv2d(in_channels=out_channels,
out_channels=out_channels,
kernel_size=(3, 3), stride=(1, 1),
padding=(1, 1), bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.init_weight()
def init_weight(self):
init_layer(self.conv1)
init_layer(self.conv2)
init_bn(self.bn1)
init_bn(self.bn2)
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
x = input
x = F.relu_(self.bn1(self.conv1(x)))
x = F.relu_(self.bn2(self.conv2(x)))
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x1 = F.avg_pool2d(x, kernel_size=pool_size)
x2 = F.max_pool2d(x, kernel_size=pool_size)
x = x1 + x2
        else:
            raise ValueError(f"Unknown pool_type: {pool_type}")
return x
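# Linear softmax pooling: each frame's score weights itself, so
# clipwise(c) = sum_t y_t(c)^2 / sum_t y_t(c). Confident frames dominate the
# clip-level decision without the winner-takes-all behaviour of max pooling.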
class LinearSoftmax(nn.Module):
def __init__(self, pooldim=1):
super().__init__()
self.pooldim = pooldim
    def forward(self, time_decision):
        return ((time_decision ** 2).sum(self.pooldim)
                / time_decision.sum(self.pooldim))
class Cnn8RnnConfig(PretrainedConfig):
def __init__(
self,
classes_num: int = 447,
sample_rate: int = 32000,
**kwargs
):
self.classes_num = classes_num
self.sample_rate = sample_rate
super().__init__(**kwargs)
class Cnn8RnnSoundEventDetection(PreTrainedModel):
config_class = Cnn8RnnConfig
def __init__(self, config: Cnn8RnnConfig):
super().__init__(config)
self.config = config
        self.time_resolution = 0.01  # seconds per spectrogram frame
        self.interpolate_ratio = 4  # time downsampling ratio of the CNN
# Logmel spectrogram extractor
self.melspec_extractor = transforms.MelSpectrogram(
sample_rate=config.sample_rate,
n_fft=1024,
win_length=1024,
hop_length=320,
f_min=50,
f_max=14000,
n_mels=64,
norm="slaney",
mel_scale="slaney"
)
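        # hop_length 320 at the default 32 kHz sample rate yields 100 frames
        # per second, matching time_resolution = 0.01 s.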
self.db_transform = transforms.AmplitudeToDB()
self.bn0 = nn.BatchNorm2d(64)
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
self.fc1 = nn.Linear(512, 512, bias=True)
self.rnn = nn.GRU(512, 256, bidirectional=True, batch_first=True)
self.fc_audioset = nn.Linear(512, config.classes_num, bias=True)
self.temporal_pooling = LinearSoftmax()
self.init_weight()
def init_weight(self):
init_bn(self.bn0)
init_layer(self.fc1)
init_layer(self.fc_audioset)
def forward(self, waveform):
x = self.melspec_extractor(waveform)
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
        x = x.transpose(1, 2)
        x = x.unsqueeze(1)  # (batch_size, 1, time_steps, mel_bins)
        frames_num = x.shape[2]
        x = x.transpose(1, 3)
        x = self.bn0(x)  # batch-normalize over the mel-bin axis
        x = x.transpose(1, 3)
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg+max')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg+max')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg+max')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg+max')
        x = F.dropout(x, p=0.2, training=self.training)  # (batch_size, 512, time_steps / 4, mel_bins / 16)
        x = torch.mean(x, dim=3)  # pool over mel bins: (batch_size, 512, time_steps / 4)
        x = x.transpose(1, 2)  # (batch_size, time_steps / 4, 512)
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        x, _ = self.rnn(x)  # bidirectional GRU: (batch_size, time_steps / 4, 512)
        segmentwise_output = torch.sigmoid(self.fc_audioset(x)).clamp(1e-7, 1.)  # per-frame class probabilities
        clipwise_output = self.temporal_pooling(segmentwise_output)  # linear softmax pooling over time
# Get framewise output
framewise_output = interpolate(segmentwise_output,
self.interpolate_ratio)
framewise_output = pad_framewise_output(framewise_output, frames_num)
output_dict = {
'framewise_output': framewise_output,
'clipwise_output': clipwise_output
}
return output_dict
    def save_pretrained(self, save_directory, *args, **kwargs):
        super().save_pretrained(save_directory, *args, **kwargs)
        # Store the class-label vocabulary next to the weights.
        with open(os.path.join(save_directory, "classes.txt"), "w") as f:
            for class_name in self.classes:
                f.write(class_name + "\n")
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        model = super().from_pretrained(pretrained_model_name_or_path,
                                        *model_args, **kwargs)
        # Fetch the class-label vocabulary shipped with the checkpoint.
        class_file = cached_file(pretrained_model_name_or_path, "classes.txt")
        with open(class_file, "r") as f:
            model.classes = [line.strip() for line in f]
        return model
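# Minimal usage sketch. The repo id below is inferred from this checkpoint's
# name and may need adjusting; the random waveform merely stands in for real
# 32 kHz mono audio.
if __name__ == "__main__":
    model = Cnn8RnnSoundEventDetection.from_pretrained(
        "wsntxxn/cnn8rnn-audioset-sed")
    model.eval()
    # 10 seconds of dummy mono audio at the model's sample rate.
    waveform = torch.randn(1, model.config.sample_rate * 10)
    with torch.no_grad():
        output = model(waveform)
    print(output["clipwise_output"].shape)   # (1, classes_num)
    print(output["framewise_output"].shape)  # (1, frames_num, classes_num)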