Upload 46 files
- audio_to_text/__init__.py +0 -0
- audio_to_text/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_to_text/__pycache__/inference_waveform.cpython-38.pyc +0 -0
- audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml +23 -0
- audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth +3 -0
- audio_to_text/captioning/__init__.py +0 -0
- audio_to_text/captioning/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__init__.py +3 -0
- audio_to_text/captioning/models/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/attn_model.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/base_model.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/decoder.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/encoder.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/fc_model.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/rl_model.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/style_model.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/transformer_model.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/__pycache__/utils.cpython-38.pyc +0 -0
- audio_to_text/captioning/models/base_model.py +500 -0
- audio_to_text/captioning/models/decoder.py +746 -0
- audio_to_text/captioning/models/encoder.py +686 -0
- audio_to_text/captioning/models/transformer_model.py +265 -0
- audio_to_text/captioning/models/utils.py +132 -0
- audio_to_text/captioning/utils/README.md +19 -0
- audio_to_text/captioning/utils/__init__.py +0 -0
- audio_to_text/captioning/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- audio_to_text/captioning/utils/__pycache__/train_util.cpython-38.pyc +0 -0
- audio_to_text/captioning/utils/bert/create_sent_embedding.py +89 -0
- audio_to_text/captioning/utils/bert/create_word_embedding.py +34 -0
- audio_to_text/captioning/utils/build_vocab.py +153 -0
- audio_to_text/captioning/utils/build_vocab_ltp.py +150 -0
- audio_to_text/captioning/utils/build_vocab_spacy.py +152 -0
- audio_to_text/captioning/utils/eval_round_robin.py +182 -0
- audio_to_text/captioning/utils/fasttext/create_word_embedding.py +50 -0
- audio_to_text/captioning/utils/lr_scheduler.py +128 -0
- audio_to_text/captioning/utils/model_eval_diff.py +110 -0
- audio_to_text/captioning/utils/predict_nn.py +49 -0
- audio_to_text/captioning/utils/remove_optimizer.py +18 -0
- audio_to_text/captioning/utils/report_results.py +37 -0
- audio_to_text/captioning/utils/tokenize_caption.py +86 -0
- audio_to_text/captioning/utils/train_util.py +178 -0
- audio_to_text/captioning/utils/word2vec/create_word_embedding.py +67 -0
- audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml +22 -0
- audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth +3 -0
- audio_to_text/inference_waveform.py +102 -0
- audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth +3 -0
audio_to_text/__init__.py
ADDED
File without changes
audio_to_text/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (157 Bytes).
audio_to_text/__pycache__/inference_waveform.cpython-38.pyc
ADDED
Binary file (3.01 kB).
audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml
ADDED
@@ -0,0 +1,23 @@
model:
  encoder:
    type: Cnn14RnnEncoder
    args:
      sample_rate: 32000
      pretrained: ./audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
      freeze_cnn: True
      freeze_cnn_bn: True
      bidirectional: True
      dropout: 0.5
      hidden_size: 256
      num_layers: 3
  decoder:
    type: TransformerDecoder
    args:
      attn_emb_dim: 512
      dropout: 0.2
      emb_dim: 256
      fc_emb_dim: 512
      nlayers: 2
  type: TransformerModel
  args: {}
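The config selects the encoder and decoder classes by name and carries their constructor arguments. A minimal sketch (not part of this upload, assuming PyYAML is installed and the repository root is the working directory) of reading it:

import yaml

with open("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["type"])             # TransformerModel
print(cfg["model"]["encoder"]["type"])  # Cnn14RnnEncoder
print(cfg["model"]["decoder"]["args"])  # {'attn_emb_dim': 512, 'dropout': 0.2, ...}

How these names are mapped onto the classes in audio_to_text/captioning/models is handled by the repository's own loading code (presumably inference_waveform.py), which is not reproduced here.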
audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d22099e1025baae0f32ce09ec02c3d5fea001e295512fbf8754b5c66db21b0ec
size 43027289
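swa.pth is stored as a Git LFS pointer; the roughly 43 MB checkpoint itself is fetched by git-lfs when the repo is cloned. A small sketch (assumption: the file is a torch-serialized checkpoint, as the .pth extension suggests) for inspecting it on CPU:

import torch

ckpt = torch.load("audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth",
                  map_location="cpu")
print(type(ckpt))  # usually a state dict, or a wrapper dict containing one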
audio_to_text/captioning/__init__.py
ADDED
File without changes
audio_to_text/captioning/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (168 Bytes).
audio_to_text/captioning/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .base_model import *
from .transformer_model import *
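With these star imports, the classes defined in base_model.py and transformer_model.py are re-exported at the package level. A one-line usage sketch (assuming the repository root is on PYTHONPATH; TransformerModel is the type named in the config above):

from audio_to_text.captioning.models import CaptionModel, TransformerModel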
audio_to_text/captioning/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (231 Bytes).
audio_to_text/captioning/models/__pycache__/attn_model.cpython-38.pyc
ADDED
Binary file (7.73 kB).
audio_to_text/captioning/models/__pycache__/base_model.cpython-38.pyc
ADDED
Binary file (15.8 kB).
audio_to_text/captioning/models/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (19.1 kB).
audio_to_text/captioning/models/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (19.4 kB).
audio_to_text/captioning/models/__pycache__/fc_model.cpython-38.pyc
ADDED
Binary file (3.5 kB).
audio_to_text/captioning/models/__pycache__/rl_model.cpython-38.pyc
ADDED
Binary file (2.19 kB).
audio_to_text/captioning/models/__pycache__/style_model.cpython-38.pyc
ADDED
Binary file (3.4 kB).
audio_to_text/captioning/models/__pycache__/transformer_model.cpython-38.pyc
ADDED
Binary file (7.6 kB).
audio_to_text/captioning/models/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.16 kB).
audio_to_text/captioning/models/base_model.py
ADDED
@@ -0,0 +1,500 @@
# -*- coding: utf-8 -*-

from typing import Dict

import torch
import torch.nn as nn

from .utils import mean_with_lens, repeat_tensor


class CaptionModel(nn.Module):
    """
    Encoder-decoder captioning model.
    """

    pad_idx = 0
    start_idx = 1
    end_idx = 2
    max_length = 20

    def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
        self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
        self.inference_forward_keys = ["sample_method", "max_length", "temp"]
        freeze_encoder = kwargs.get("freeze_encoder", False)
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.check_decoder_compatibility()

    def check_decoder_compatibility(self):
        compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
        assert isinstance(self.decoder, self.compatible_decoders), \
            f"{self.decoder.__class__.__name__} is incompatible with " \
            f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "

    @classmethod
    def set_index(cls, start_idx, end_idx):
        cls.start_idx = start_idx
        cls.end_idx = end_idx

    def forward(self, input_dict: Dict):
        """
        input_dict: {
            (required)
            mode: train/inference,
            spec,
            spec_len,
            fc,
            attn,
            attn_len,
            [sample_method: greedy],
            [temp: 1.0] (in case of no teacher forcing)

            (optional, mode=train)
            cap,
            cap_len,
            ss_ratio,

            (optional, mode=inference)
            sample_method: greedy/beam,
            max_length,
            temp,
            beam_size (optional, sample_method=beam),
            n_best (optional, sample_method=beam),
        }
        """
        # encoder_input_keys = ["spec", "spec_len", "fc", "attn", "attn_len"]
        # encoder_input = { key: input_dict[key] for key in encoder_input_keys }
        encoder_output_dict = self.encoder(input_dict)
        if input_dict["mode"] == "train":
            forward_dict = {
                "mode": "train", "sample_method": "greedy", "temp": 1.0
            }
            for key in self.train_forward_keys:
                forward_dict[key] = input_dict[key]
            forward_dict.update(encoder_output_dict)
            output = self.train_forward(forward_dict)
        elif input_dict["mode"] == "inference":
            forward_dict = {"mode": "inference"}
            default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
            for key in self.inference_forward_keys:
                if key in input_dict:
                    forward_dict[key] = input_dict[key]
                else:
                    forward_dict[key] = default_args[key]

            if forward_dict["sample_method"] == "beam":
                forward_dict["beam_size"] = input_dict.get("beam_size", 3)
                forward_dict["n_best"] = input_dict.get("n_best", False)
                forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
            elif forward_dict["sample_method"] == "dbs":
                forward_dict["beam_size"] = input_dict.get("beam_size", 6)
                forward_dict["group_size"] = input_dict.get("group_size", 3)
                forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
                forward_dict["group_nbest"] = input_dict.get("group_nbest", True)

            forward_dict.update(encoder_output_dict)
            output = self.inference_forward(forward_dict)
        else:
            raise Exception("mode should be either 'train' or 'inference'")

        return output

    def prepare_output(self, input_dict):
        output = {}
        batch_size = input_dict["fc_emb"].size(0)
        if input_dict["mode"] == "train":
            max_length = input_dict["cap"].size(1) - 1
        elif input_dict["mode"] == "inference":
            max_length = input_dict["max_length"]
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        device = input_dict["fc_emb"].device
        output["seq"] = torch.full((batch_size, max_length), self.end_idx,
                                   dtype=torch.long)
        output["logit"] = torch.empty(batch_size, max_length,
                                      self.vocab_size).to(device)
        output["sampled_logprob"] = torch.zeros(batch_size, max_length)
        output["embed"] = torch.empty(batch_size, max_length,
                                      self.decoder.d_model).to(device)
        return output

    def train_forward(self, input_dict):
        if input_dict["ss_ratio"] != 1: # scheduled sampling training
            input_dict["mode"] = "train"
            return self.stepwise_forward(input_dict)
        output = self.seq_forward(input_dict)
        self.train_process(output, input_dict)
        return output

    def seq_forward(self, input_dict):
        raise NotImplementedError

    def train_process(self, output, input_dict):
        pass

    def inference_forward(self, input_dict):
        if input_dict["sample_method"] == "beam":
            return self.beam_search(input_dict)
        elif input_dict["sample_method"] == "dbs":
            return self.diverse_beam_search(input_dict)
        return self.stepwise_forward(input_dict)

    def stepwise_forward(self, input_dict):
        """Step-by-step decoding"""
        output = self.prepare_output(input_dict)
        max_length = output["seq"].size(1)
        # start sampling
        for t in range(max_length):
            input_dict["t"] = t
            self.decode_step(input_dict, output)
            if input_dict["mode"] == "inference": # decide whether to stop when sampling
                unfinished_t = output["seq"][:, t] != self.end_idx
                if t == 0:
                    unfinished = unfinished_t
                else:
                    unfinished *= unfinished_t
                output["seq"][:, t][~unfinished] = self.end_idx
                if unfinished.sum() == 0:
                    break
        self.stepwise_process(output)
        return output

    def decode_step(self, input_dict, output):
        """Decoding operation of timestep t"""
        decoder_input = self.prepare_decoder_input(input_dict, output)
        # feed to the decoder to get logit
        output_t = self.decoder(decoder_input)
        logit_t = output_t["logit"]
        # assert logit_t.ndim == 3
        if logit_t.size(1) == 1:
            logit_t = logit_t.squeeze(1)
            embed_t = output_t["embed"].squeeze(1)
        elif logit_t.size(1) > 1:
            logit_t = logit_t[:, -1, :]
            embed_t = output_t["embed"][:, -1, :]
        else:
            raise Exception("no logit output")
        # sample the next input word and get the corresponding logit
        sampled = self.sample_next_word(logit_t,
                                        method=input_dict["sample_method"],
                                        temp=input_dict["temp"])

        output_t.update(sampled)
        output_t["t"] = input_dict["t"]
        output_t["logit"] = logit_t
        output_t["embed"] = embed_t
        self.stepwise_process_step(output, output_t)

    def prepare_decoder_input(self, input_dict, output):
        """Prepare the input dict for the decoder"""
        raise NotImplementedError

    def stepwise_process_step(self, output, output_t):
        """Postprocessing (save output values) after each timestep t"""
        t = output_t["t"]
        output["logit"][:, t, :] = output_t["logit"]
        output["seq"][:, t] = output_t["word"]
        output["sampled_logprob"][:, t] = output_t["probs"]
        output["embed"][:, t, :] = output_t["embed"]

    def stepwise_process(self, output):
        """Postprocessing after the whole step-by-step autoregressive decoding"""
        pass

    def sample_next_word(self, logit, method, temp):
        """Sample the next word, given probs output by the decoder"""
        logprob = torch.log_softmax(logit, dim=1)
        if method == "greedy":
            sampled_logprob, word = torch.max(logprob.detach(), 1)
        elif method == "gumbel":
            def sample_gumbel(shape, eps=1e-20):
                U = torch.rand(shape).to(logprob.device)
                return -torch.log(-torch.log(U + eps) + eps)
            def gumbel_softmax_sample(logit, temperature):
                y = logit + sample_gumbel(logit.size())
                return torch.log_softmax(y / temperature, dim=-1)
            _logprob = gumbel_softmax_sample(logprob, temp)
            _, word = torch.max(_logprob.data, 1)
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
        else:
            logprob = logprob / temp
            if method.startswith("top"):
                top_num = float(method[3:])
                if 0 < top_num < 1: # top-p sampling
                    probs = torch.softmax(logit, dim=1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
                    _cumsum = sorted_probs.cumsum(1)
                    mask = _cumsum < top_num
                    mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
                    sorted_probs = sorted_probs * mask.to(sorted_probs)
                    sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
                    logprob.scatter_(1, sorted_indices, sorted_probs.log())
                else: # top-k sampling
                    k = int(top_num)
                    tmp = torch.empty_like(logprob).fill_(float('-inf'))
                    topk, indices = torch.topk(logprob, k, dim=1)
                    tmp = tmp.scatter(1, indices, topk)
                    logprob = tmp
            word = torch.distributions.Categorical(logits=logprob.detach()).sample()
            sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
        word = word.detach().long()
        # sampled_logprob: [N,], word: [N,]
        return {"word": word, "probs": sampled_logprob}

    def beam_search(self, input_dict):
        output = self.prepare_output(input_dict)
        max_length = input_dict["max_length"]
        beam_size = input_dict["beam_size"]
        if input_dict["n_best"]:
            n_best_size = input_dict["n_best_size"]
            batch_size, max_length = output["seq"].size()
            output["seq"] = torch.full((batch_size, n_best_size, max_length),
                                       self.end_idx, dtype=torch.long)

        temp = input_dict["temp"]
        # instance by instance beam search
        for i in range(output["seq"].size(0)):
            output_i = self.prepare_beamsearch_output(input_dict)
            input_dict["sample_idx"] = i
            for t in range(max_length):
                input_dict["t"] = t
                output_t = self.beamsearch_step(input_dict, output_i)
                #######################################
                # merge with previous beam and select the current max prob beam
                #######################################
                logit_t = output_t["logit"]
                if logit_t.size(1) == 1:
                    logit_t = logit_t.squeeze(1)
                elif logit_t.size(1) > 1:
                    logit_t = logit_t[:, -1, :]
                else:
                    raise Exception("no logit output")
                logprob_t = torch.log_softmax(logit_t, dim=1)
                logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
                if t == 0: # for the first step, all k seq will have the same probs
                    topk_logprob, topk_words = logprob_t[0].topk(
                        beam_size, 0, True, True)
                else: # unroll and find top logprob, and their unrolled indices
                    topk_logprob, topk_words = logprob_t.view(-1).topk(
                        beam_size, 0, True, True)
                topk_words = topk_words.cpu()
                output_i["topk_logprob"] = topk_logprob
                # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
                output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
                                                        rounding_mode='trunc')
                output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
                if t == 0:
                    output_i["seq"] = output_i["next_word"].unsqueeze(1)
                else:
                    output_i["seq"] = torch.cat([
                        output_i["seq"][output_i["prev_words_beam"]],
                        output_i["next_word"].unsqueeze(1)], dim=1)

                # add finished beams to results
                is_end = output_i["next_word"] == self.end_idx
                if t == max_length - 1:
                    is_end.fill_(1)

                for beam_idx in range(beam_size):
                    if is_end[beam_idx]:
                        final_beam = {
                            "seq": output_i["seq"][beam_idx].clone(),
                            "score": output_i["topk_logprob"][beam_idx].item()
                        }
                        final_beam["score"] = final_beam["score"] / (t + 1)
                        output_i["done_beams"].append(final_beam)
                output_i["topk_logprob"][is_end] -= 1000

                self.beamsearch_process_step(output_i, output_t)

            self.beamsearch_process(output, output_i, input_dict)
        return output

    def prepare_beamsearch_output(self, input_dict):
        beam_size = input_dict["beam_size"]
        device = input_dict["fc_emb"].device
        output = {
            "topk_logprob": torch.zeros(beam_size).to(device),
            "seq": None,
            "prev_words_beam": None,
            "next_word": None,
            "done_beams": [],
        }
        return output

    def beamsearch_step(self, input_dict, output_i):
        decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
        output_t = self.decoder(decoder_input)
        output_t["t"] = input_dict["t"]
        return output_t

    def prepare_beamsearch_decoder_input(self, input_dict, output_i):
        raise NotImplementedError

    def beamsearch_process_step(self, output_i, output_t):
        pass

    def beamsearch_process(self, output, output_i, input_dict):
        i = input_dict["sample_idx"]
        done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
        if input_dict["n_best"]:
            done_beams = done_beams[:input_dict["n_best_size"]]
            for out_idx, done_beam in enumerate(done_beams):
                seq = done_beam["seq"]
                output["seq"][i][out_idx, :len(seq)] = seq
        else:
            seq = done_beams[0]["seq"]
            output["seq"][i][:len(seq)] = seq

    def diverse_beam_search(self, input_dict):

        def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprob = logprob.clone()

            if divm > 0:
                change = torch.zeros(logprob.size(-1))
                for prev_choice in range(divm):
                    prev_decisions = seq_table[prev_choice][..., local_time]
                    for prev_labels in range(bdash):
                        change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))

                change = change.to(logprob.device)
                logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda

            return logprob, unaug_logprob

        output = self.prepare_output(input_dict)
        group_size = input_dict["group_size"]
        batch_size = output["seq"].size(0)
        beam_size = input_dict["beam_size"]
        bdash = beam_size // group_size
        input_dict["bdash"] = bdash
        diversity_lambda = input_dict["diversity_lambda"]
        device = input_dict["fc_emb"].device
        max_length = input_dict["max_length"]
        temp = input_dict["temp"]
        group_nbest = input_dict["group_nbest"]
        batch_size, max_length = output["seq"].size()
        if group_nbest:
            output["seq"] = torch.full((batch_size, beam_size, max_length),
                                       self.end_idx, dtype=torch.long)
        else:
            output["seq"] = torch.full((batch_size, group_size, max_length),
                                       self.end_idx, dtype=torch.long)

        for i in range(batch_size):
            input_dict["sample_idx"] = i
            seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
            logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
            done_beams_table = [[] for _ in range(group_size)]

            output_i = {
                "prev_words_beam": [None for _ in range(group_size)],
                "next_word": [None for _ in range(group_size)],
                "state": [None for _ in range(group_size)]
            }

            for t in range(max_length + group_size - 1):
                input_dict["t"] = t
                for divm in range(group_size):
                    input_dict["divm"] = divm
                    if t >= divm and t <= max_length + divm - 1:
                        local_time = t - divm
                        decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
                        output_t = self.decoder(decoder_input)
                        output_t["divm"] = divm
                        logit_t = output_t["logit"]
                        if logit_t.size(1) == 1:
                            logit_t = logit_t.squeeze(1)
                        elif logit_t.size(1) > 1:
                            logit_t = logit_t[:, -1, :]
                        else:
                            raise Exception("no logit output")
                        logprob_t = torch.log_softmax(logit_t, dim=1)
                        logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
                        logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
                        logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
                        if local_time == 0: # for the first step, all k seq will have the same probs
                            topk_logprob, topk_words = logprob_t[0].topk(
                                bdash, 0, True, True)
                        else: # unroll and find top logprob, and their unrolled indices
                            topk_logprob, topk_words = logprob_t.view(-1).topk(
                                bdash, 0, True, True)
                        topk_words = topk_words.cpu()
                        logprob_table[divm] = topk_logprob
                        output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
                        output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
                        if local_time > 0:
                            seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
                        seq_table[divm] = torch.cat([
                            seq_table[divm],
                            output_i["next_word"][divm].unsqueeze(-1)], -1)

                        is_end = seq_table[divm][:, t-divm] == self.end_idx
                        assert seq_table[divm].shape[-1] == t - divm + 1
                        if t == max_length + divm - 1:
                            is_end.fill_(1)
                        for beam_idx in range(bdash):
                            if is_end[beam_idx]:
                                final_beam = {
                                    "seq": seq_table[divm][beam_idx].clone(),
                                    "score": logprob_table[divm][beam_idx].item()
                                }
                                final_beam["score"] = final_beam["score"] / (t - divm + 1)
                                done_beams_table[divm].append(final_beam)
                        logprob_table[divm][is_end] -= 1000
                        self.dbs_process_step(output_i, output_t)
            done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
            if group_nbest:
                done_beams = sum(done_beams_table, [])
            else:
                done_beams = [group_beam[0] for group_beam in done_beams_table]
            for _, done_beam in enumerate(done_beams):
                output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]

        return output

    def prepare_dbs_decoder_input(self, input_dict, output_i):
        raise NotImplementedError

    def dbs_process_step(self, output_i, output_t):
        pass


class CaptionSequenceModel(nn.Module):

    def __init__(self, model, seq_output_size):
        super().__init__()
        self.model = model
        if model.decoder.d_model != seq_output_size:
            self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
        else:
            self.output_transform = lambda x: x

    def forward(self, input_dict):
        output = self.model(input_dict)

        if input_dict["mode"] == "train":
            lens = input_dict["cap_len"] - 1
            # seq_outputs: [N, d_model]
        elif input_dict["mode"] == "inference":
            if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
                return output
            seq = output["seq"]
            lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
        else:
            raise Exception("mode should be either 'train' or 'inference'")
        seq_output = mean_with_lens(output["embed"], lens)
        seq_output = self.output_transform(seq_output)
        output["seq_output"] = seq_output
        return output
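CaptionModel is abstract: a concrete subclass supplies prepare_decoder_input (and seq_forward for teacher-forced training), while forward dispatches on input_dict["mode"] and sample_method. A minimal runnable sketch of greedy stepwise inference (ToyEncoder, ToyDecoder and GreedyCaptionModel are hypothetical stand-ins for illustration, not classes from this upload):

import torch
import torch.nn as nn
from audio_to_text.captioning.models.base_model import CaptionModel

class ToyEncoder(nn.Module):
    # Returns the keys that prepare_output / the decoder expect downstream.
    def __init__(self, dim=8):
        super().__init__()
        self.dim = dim
    def forward(self, input_dict):
        n = input_dict["wav"].size(0)
        return {"fc_emb": torch.zeros(n, self.dim),
                "attn_emb": torch.zeros(n, 4, self.dim),
                "attn_emb_len": torch.full((n,), 4)}

class ToyDecoder(nn.Module):
    # One decoding step: word indices in, per-step logit and embedding out.
    def __init__(self, vocab_size=10, d_model=8):
        super().__init__()
        self.vocab_size, self.d_model = vocab_size, d_model
        self.emb = nn.Embedding(vocab_size, d_model)
        self.cls = nn.Linear(d_model, vocab_size)
    def forward(self, input_dict):
        e = self.emb(input_dict["word"])          # [N, 1, d_model]
        return {"embed": e, "logit": self.cls(e)} # logit: [N, 1, vocab_size]

class GreedyCaptionModel(CaptionModel):
    compatible_decoders = (ToyDecoder,)
    def prepare_decoder_input(self, input_dict, output):
        t = input_dict["t"]
        if t == 0:  # feed the start token first, then the previously sampled word
            word = torch.full((output["seq"].size(0), 1), self.start_idx, dtype=torch.long)
        else:
            word = output["seq"][:, t - 1].unsqueeze(1)
        return {"word": word, "fc_emb": input_dict["fc_emb"]}

model = GreedyCaptionModel(ToyEncoder(), ToyDecoder())
out = model({"mode": "inference", "wav": torch.zeros(2, 16000),
             "sample_method": "greedy", "max_length": 20, "temp": 1.0})
print(out["seq"].shape)  # torch.Size([2, 20]), rows end-padded with end_idx (2)

In this upload, TransformerModel from transformer_model.py presumably plays the role of the concrete subclass, pairing the Cnn14RnnEncoder output with the TransformerDecoder defined in decoder.py below.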
audio_to_text/captioning/models/decoder.py
ADDED
@@ -0,0 +1,746 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import math
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from .utils import generate_length_mask, init, PositionalEncoding
|
11 |
+
|
12 |
+
|
13 |
+
class BaseDecoder(nn.Module):
|
14 |
+
"""
|
15 |
+
Take word/audio embeddings and output the next word probs
|
16 |
+
Base decoder, cannot be called directly
|
17 |
+
All decoders should inherit from this class
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim,
|
21 |
+
attn_emb_dim, dropout=0.2):
|
22 |
+
super().__init__()
|
23 |
+
self.emb_dim = emb_dim
|
24 |
+
self.vocab_size = vocab_size
|
25 |
+
self.fc_emb_dim = fc_emb_dim
|
26 |
+
self.attn_emb_dim = attn_emb_dim
|
27 |
+
self.word_embedding = nn.Embedding(vocab_size, emb_dim)
|
28 |
+
self.in_dropout = nn.Dropout(dropout)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
raise NotImplementedError
|
32 |
+
|
33 |
+
def load_word_embedding(self, weight, freeze=True):
|
34 |
+
embedding = np.load(weight)
|
35 |
+
assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
|
36 |
+
assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
|
37 |
+
|
38 |
+
# embeddings = torch.as_tensor(embeddings).float()
|
39 |
+
# self.word_embeddings.weight = nn.Parameter(embeddings)
|
40 |
+
# for para in self.word_embeddings.parameters():
|
41 |
+
# para.requires_grad = tune
|
42 |
+
self.word_embedding = nn.Embedding.from_pretrained(embedding,
|
43 |
+
freeze=freeze)
|
44 |
+
|
45 |
+
|
46 |
+
class RnnDecoder(BaseDecoder):
|
47 |
+
|
48 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
49 |
+
dropout, d_model, **kwargs):
|
50 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
51 |
+
dropout,)
|
52 |
+
self.d_model = d_model
|
53 |
+
self.num_layers = kwargs.get('num_layers', 1)
|
54 |
+
self.bidirectional = kwargs.get('bidirectional', False)
|
55 |
+
self.rnn_type = kwargs.get('rnn_type', "GRU")
|
56 |
+
self.classifier = nn.Linear(
|
57 |
+
self.d_model * (self.bidirectional + 1), vocab_size)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
raise NotImplementedError
|
61 |
+
|
62 |
+
def init_hidden(self, bs, device):
|
63 |
+
num_dire = self.bidirectional + 1
|
64 |
+
n_layer = self.num_layers
|
65 |
+
hid_dim = self.d_model
|
66 |
+
if self.rnn_type == "LSTM":
|
67 |
+
return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
|
68 |
+
torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
|
69 |
+
else:
|
70 |
+
return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
|
71 |
+
|
72 |
+
|
73 |
+
class RnnFcDecoder(RnnDecoder):
|
74 |
+
|
75 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs):
|
76 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs)
|
77 |
+
self.model = getattr(nn, self.rnn_type)(
|
78 |
+
input_size=self.emb_dim * 2,
|
79 |
+
hidden_size=self.d_model,
|
80 |
+
batch_first=True,
|
81 |
+
num_layers=self.num_layers,
|
82 |
+
bidirectional=self.bidirectional)
|
83 |
+
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
|
84 |
+
self.apply(init)
|
85 |
+
|
86 |
+
def forward(self, input_dict):
|
87 |
+
word = input_dict["word"]
|
88 |
+
state = input_dict.get("state", None)
|
89 |
+
fc_emb = input_dict["fc_emb"]
|
90 |
+
|
91 |
+
word = word.to(fc_emb.device)
|
92 |
+
embed = self.in_dropout(self.word_embedding(word))
|
93 |
+
|
94 |
+
p_fc_emb = self.fc_proj(fc_emb)
|
95 |
+
# embed: [N, T, embed_size]
|
96 |
+
embed = torch.cat((embed, p_fc_emb), dim=-1)
|
97 |
+
|
98 |
+
out, state = self.model(embed, state)
|
99 |
+
# out: [N, T, hs], states: [num_layers * num_dire, N, hs]
|
100 |
+
logits = self.classifier(out)
|
101 |
+
output = {
|
102 |
+
"state": state,
|
103 |
+
"embeds": out,
|
104 |
+
"logits": logits
|
105 |
+
}
|
106 |
+
|
107 |
+
return output
|
108 |
+
|
109 |
+
|
110 |
+
class Seq2SeqAttention(nn.Module):
|
111 |
+
|
112 |
+
def __init__(self, hs_enc, hs_dec, attn_size):
|
113 |
+
"""
|
114 |
+
Args:
|
115 |
+
hs_enc: encoder hidden size
|
116 |
+
hs_dec: decoder hidden size
|
117 |
+
attn_size: attention vector size
|
118 |
+
"""
|
119 |
+
super(Seq2SeqAttention, self).__init__()
|
120 |
+
self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
|
121 |
+
self.v = nn.Parameter(torch.randn(attn_size))
|
122 |
+
self.apply(init)
|
123 |
+
|
124 |
+
def forward(self, h_dec, h_enc, src_lens):
|
125 |
+
"""
|
126 |
+
Args:
|
127 |
+
h_dec: decoder hidden (query), [N, hs_dec]
|
128 |
+
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
|
129 |
+
src_lens: source (encoder memory) lengths, [N, ]
|
130 |
+
"""
|
131 |
+
N = h_enc.size(0)
|
132 |
+
src_max_len = h_enc.size(1)
|
133 |
+
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
|
134 |
+
|
135 |
+
attn_input = torch.cat((h_dec, h_enc), dim=-1)
|
136 |
+
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
|
137 |
+
|
138 |
+
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
|
139 |
+
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
|
140 |
+
|
141 |
+
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
|
142 |
+
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
|
143 |
+
|
144 |
+
score = score.masked_fill(mask == 0, -1e10)
|
145 |
+
weights = torch.softmax(score, dim=-1) # [N, src_max_len]
|
146 |
+
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
|
147 |
+
|
148 |
+
return ctx, weights
|
149 |
+
|
150 |
+
|
151 |
+
class AttentionProj(nn.Module):
|
152 |
+
|
153 |
+
def __init__(self, hs_enc, hs_dec, embed_dim, attn_size):
|
154 |
+
self.q_proj = nn.Linear(hs_dec, embed_dim)
|
155 |
+
self.kv_proj = nn.Linear(hs_enc, embed_dim)
|
156 |
+
self.h2attn = nn.Linear(embed_dim * 2, attn_size)
|
157 |
+
self.v = nn.Parameter(torch.randn(attn_size))
|
158 |
+
self.apply(init)
|
159 |
+
|
160 |
+
def init(self, m):
|
161 |
+
if isinstance(m, nn.Linear):
|
162 |
+
nn.init.kaiming_uniform_(m.weight)
|
163 |
+
if m.bias is not None:
|
164 |
+
nn.init.constant_(m.bias, 0)
|
165 |
+
|
166 |
+
def forward(self, h_dec, h_enc, src_lens):
|
167 |
+
"""
|
168 |
+
Args:
|
169 |
+
h_dec: decoder hidden (query), [N, hs_dec]
|
170 |
+
h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
|
171 |
+
src_lens: source (encoder memory) lengths, [N, ]
|
172 |
+
"""
|
173 |
+
h_enc = self.kv_proj(h_enc) # [N, src_max_len, embed_dim]
|
174 |
+
h_dec = self.q_proj(h_dec) # [N, embed_dim]
|
175 |
+
N = h_enc.size(0)
|
176 |
+
src_max_len = h_enc.size(1)
|
177 |
+
h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
|
178 |
+
|
179 |
+
attn_input = torch.cat((h_dec, h_enc), dim=-1)
|
180 |
+
attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
|
181 |
+
|
182 |
+
v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
|
183 |
+
score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
|
184 |
+
|
185 |
+
idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
|
186 |
+
mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
|
187 |
+
|
188 |
+
score = score.masked_fill(mask == 0, -1e10)
|
189 |
+
weights = torch.softmax(score, dim=-1) # [N, src_max_len]
|
190 |
+
ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
|
191 |
+
|
192 |
+
return ctx, weights
|
193 |
+
|
194 |
+
|
195 |
+
class BahAttnDecoder(RnnDecoder):
|
196 |
+
|
197 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
198 |
+
dropout, d_model, **kwargs):
|
199 |
+
"""
|
200 |
+
concatenate fc, attn, word to feed to the rnn
|
201 |
+
"""
|
202 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
203 |
+
dropout, d_model, **kwargs)
|
204 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
205 |
+
self.model = getattr(nn, self.rnn_type)(
|
206 |
+
input_size=self.emb_dim * 3,
|
207 |
+
hidden_size=self.d_model,
|
208 |
+
batch_first=True,
|
209 |
+
num_layers=self.num_layers,
|
210 |
+
bidirectional=self.bidirectional)
|
211 |
+
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
212 |
+
self.d_model * (self.bidirectional + 1) * \
|
213 |
+
self.num_layers,
|
214 |
+
attn_size)
|
215 |
+
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
|
216 |
+
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
217 |
+
self.apply(init)
|
218 |
+
|
219 |
+
def forward(self, input_dict):
|
220 |
+
word = input_dict["word"]
|
221 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
222 |
+
fc_emb = input_dict["fc_emb"]
|
223 |
+
attn_emb = input_dict["attn_emb"]
|
224 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
225 |
+
|
226 |
+
word = word.to(fc_emb.device)
|
227 |
+
embed = self.in_dropout(self.word_embedding(word))
|
228 |
+
|
229 |
+
# embed: [N, 1, embed_size]
|
230 |
+
if state is None:
|
231 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
232 |
+
if self.rnn_type == "LSTM":
|
233 |
+
query = state[0].transpose(0, 1).flatten(1)
|
234 |
+
else:
|
235 |
+
query = state.transpose(0, 1).flatten(1)
|
236 |
+
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
237 |
+
|
238 |
+
p_fc_emb = self.fc_proj(fc_emb)
|
239 |
+
p_ctx = self.ctx_proj(c)
|
240 |
+
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
|
241 |
+
dim=-1)
|
242 |
+
|
243 |
+
out, state = self.model(rnn_input, state)
|
244 |
+
|
245 |
+
output = {
|
246 |
+
"state": state,
|
247 |
+
"embed": out,
|
248 |
+
"logit": self.classifier(out),
|
249 |
+
"attn_weight": attn_weight
|
250 |
+
}
|
251 |
+
return output
|
252 |
+
|
253 |
+
|
254 |
+
class BahAttnDecoder2(RnnDecoder):
|
255 |
+
|
256 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
257 |
+
dropout, d_model, **kwargs):
|
258 |
+
"""
|
259 |
+
add fc, attn, word together to feed to the rnn
|
260 |
+
"""
|
261 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
262 |
+
dropout, d_model, **kwargs)
|
263 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
264 |
+
self.model = getattr(nn, self.rnn_type)(
|
265 |
+
input_size=self.emb_dim,
|
266 |
+
hidden_size=self.d_model,
|
267 |
+
batch_first=True,
|
268 |
+
num_layers=self.num_layers,
|
269 |
+
bidirectional=self.bidirectional)
|
270 |
+
self.attn = Seq2SeqAttention(self.emb_dim,
|
271 |
+
self.d_model * (self.bidirectional + 1) * \
|
272 |
+
self.num_layers,
|
273 |
+
attn_size)
|
274 |
+
self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
|
275 |
+
self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
276 |
+
self.apply(partial(init, method="xavier"))
|
277 |
+
|
278 |
+
def forward(self, input_dict):
|
279 |
+
word = input_dict["word"]
|
280 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
281 |
+
fc_emb = input_dict["fc_emb"]
|
282 |
+
attn_emb = input_dict["attn_emb"]
|
283 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
284 |
+
|
285 |
+
word = word.to(fc_emb.device)
|
286 |
+
embed = self.in_dropout(self.word_embedding(word))
|
287 |
+
p_attn_emb = self.attn_proj(attn_emb)
|
288 |
+
|
289 |
+
# embed: [N, 1, embed_size]
|
290 |
+
if state is None:
|
291 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
292 |
+
if self.rnn_type == "LSTM":
|
293 |
+
query = state[0].transpose(0, 1).flatten(1)
|
294 |
+
else:
|
295 |
+
query = state.transpose(0, 1).flatten(1)
|
296 |
+
c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len)
|
297 |
+
|
298 |
+
p_fc_emb = self.fc_proj(fc_emb)
|
299 |
+
rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1)
|
300 |
+
|
301 |
+
out, state = self.model(rnn_input, state)
|
302 |
+
|
303 |
+
output = {
|
304 |
+
"state": state,
|
305 |
+
"embed": out,
|
306 |
+
"logit": self.classifier(out),
|
307 |
+
"attn_weight": attn_weight
|
308 |
+
}
|
309 |
+
return output
|
310 |
+
|
311 |
+
|
312 |
+
class ConditionalBahAttnDecoder(RnnDecoder):
|
313 |
+
|
314 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
315 |
+
dropout, d_model, **kwargs):
|
316 |
+
"""
|
317 |
+
concatenate fc, attn, word to feed to the rnn
|
318 |
+
"""
|
319 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
320 |
+
dropout, d_model, **kwargs)
|
321 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
322 |
+
self.model = getattr(nn, self.rnn_type)(
|
323 |
+
input_size=self.emb_dim * 3,
|
324 |
+
hidden_size=self.d_model,
|
325 |
+
batch_first=True,
|
326 |
+
num_layers=self.num_layers,
|
327 |
+
bidirectional=self.bidirectional)
|
328 |
+
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
329 |
+
self.d_model * (self.bidirectional + 1) * \
|
330 |
+
self.num_layers,
|
331 |
+
attn_size)
|
332 |
+
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
333 |
+
self.condition_embedding = nn.Embedding(2, emb_dim)
|
334 |
+
self.apply(init)
|
335 |
+
|
336 |
+
def forward(self, input_dict):
|
337 |
+
word = input_dict["word"]
|
338 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
339 |
+
fc_emb = input_dict["fc_emb"]
|
340 |
+
attn_emb = input_dict["attn_emb"]
|
341 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
342 |
+
condition = input_dict["condition"]
|
343 |
+
|
344 |
+
word = word.to(fc_emb.device)
|
345 |
+
embed = self.in_dropout(self.word_embedding(word))
|
346 |
+
|
347 |
+
condition = torch.as_tensor([[1 - c, c] for c in condition]).to(fc_emb.device)
|
348 |
+
condition_emb = torch.matmul(condition, self.condition_embedding.weight)
|
349 |
+
# condition_embs: [N, emb_dim]
|
350 |
+
|
351 |
+
# embed: [N, 1, embed_size]
|
352 |
+
if state is None:
|
353 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
354 |
+
if self.rnn_type == "LSTM":
|
355 |
+
query = state[0].transpose(0, 1).flatten(1)
|
356 |
+
else:
|
357 |
+
query = state.transpose(0, 1).flatten(1)
|
358 |
+
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
359 |
+
|
360 |
+
p_ctx = self.ctx_proj(c)
|
361 |
+
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)),
|
362 |
+
dim=-1)
|
363 |
+
|
364 |
+
out, state = self.model(rnn_input, state)
|
365 |
+
|
366 |
+
output = {
|
367 |
+
"state": state,
|
368 |
+
"embed": out,
|
369 |
+
"logit": self.classifier(out),
|
370 |
+
"attn_weight": attn_weight
|
371 |
+
}
|
372 |
+
return output
|
373 |
+
|
374 |
+
|
375 |
+
class StructBahAttnDecoder(RnnDecoder):
|
376 |
+
|
377 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size,
|
378 |
+
attn_emb_dim, dropout, d_model, **kwargs):
|
379 |
+
"""
|
380 |
+
concatenate fc, attn, word to feed to the rnn
|
381 |
+
"""
|
382 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
383 |
+
dropout, d_model, **kwargs)
|
384 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
385 |
+
self.model = getattr(nn, self.rnn_type)(
|
386 |
+
input_size=self.emb_dim * 3,
|
387 |
+
hidden_size=self.d_model,
|
388 |
+
batch_first=True,
|
389 |
+
num_layers=self.num_layers,
|
390 |
+
bidirectional=self.bidirectional)
|
391 |
+
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
392 |
+
self.d_model * (self.bidirectional + 1) * \
|
393 |
+
self.num_layers,
|
394 |
+
attn_size)
|
395 |
+
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
396 |
+
self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim)
|
397 |
+
self.apply(init)
|
398 |
+
|
399 |
+
def forward(self, input_dict):
|
400 |
+
word = input_dict["word"]
|
401 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
402 |
+
fc_emb = input_dict["fc_emb"]
|
403 |
+
attn_emb = input_dict["attn_emb"]
|
404 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
405 |
+
structure = input_dict["structure"]
|
406 |
+
|
407 |
+
word = word.to(fc_emb.device)
|
408 |
+
embed = self.in_dropout(self.word_embedding(word))
|
409 |
+
|
410 |
+
struct_emb = self.struct_embedding(structure)
|
411 |
+
# struct_embs: [N, emb_dim]
|
412 |
+
|
413 |
+
# embed: [N, 1, embed_size]
|
414 |
+
if state is None:
|
415 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
416 |
+
if self.rnn_type == "LSTM":
|
417 |
+
query = state[0].transpose(0, 1).flatten(1)
|
418 |
+
else:
|
419 |
+
query = state.transpose(0, 1).flatten(1)
|
420 |
+
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
421 |
+
|
422 |
+
p_ctx = self.ctx_proj(c)
|
423 |
+
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)), dim=-1)
|
424 |
+
|
425 |
+
out, state = self.model(rnn_input, state)
|
426 |
+
|
427 |
+
output = {
|
428 |
+
"state": state,
|
429 |
+
"embed": out,
|
430 |
+
"logit": self.classifier(out),
|
431 |
+
"attn_weight": attn_weight
|
432 |
+
}
|
433 |
+
return output
|
434 |
+
|
435 |
+
|
436 |
+
class StyleBahAttnDecoder(RnnDecoder):
|
437 |
+
|
438 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
439 |
+
dropout, d_model, **kwargs):
|
440 |
+
"""
|
441 |
+
concatenate fc, attn, word to feed to the rnn
|
442 |
+
"""
|
443 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
444 |
+
dropout, d_model, **kwargs)
|
445 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
446 |
+
self.model = getattr(nn, self.rnn_type)(
|
447 |
+
input_size=self.emb_dim * 3,
|
448 |
+
hidden_size=self.d_model,
|
449 |
+
batch_first=True,
|
450 |
+
num_layers=self.num_layers,
|
451 |
+
bidirectional=self.bidirectional)
|
452 |
+
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
453 |
+
self.d_model * (self.bidirectional + 1) * \
|
454 |
+
self.num_layers,
|
455 |
+
attn_size)
|
456 |
+
self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
|
457 |
+
self.apply(init)
|
458 |
+
|
459 |
+
def forward(self, input_dict):
|
460 |
+
word = input_dict["word"]
|
461 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
462 |
+
fc_emb = input_dict["fc_emb"]
|
463 |
+
attn_emb = input_dict["attn_emb"]
|
464 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
465 |
+
style = input_dict["style"]
|
466 |
+
|
467 |
+
word = word.to(fc_emb.device)
|
468 |
+
embed = self.in_dropout(self.word_embedding(word))
|
469 |
+
|
470 |
+
# embed: [N, 1, embed_size]
|
471 |
+
if state is None:
|
472 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
473 |
+
if self.rnn_type == "LSTM":
|
474 |
+
query = state[0].transpose(0, 1).flatten(1)
|
475 |
+
else:
|
476 |
+
query = state.transpose(0, 1).flatten(1)
|
477 |
+
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
478 |
+
|
479 |
+
p_ctx = self.ctx_proj(c)
|
480 |
+
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)),
|
481 |
+
dim=-1)
|
482 |
+
|
483 |
+
out, state = self.model(rnn_input, state)
|
484 |
+
|
485 |
+
output = {
|
486 |
+
"state": state,
|
487 |
+
"embed": out,
|
488 |
+
"logit": self.classifier(out),
|
489 |
+
"attn_weight": attn_weight
|
490 |
+
}
|
491 |
+
return output
|
492 |
+
|
493 |
+
|
494 |
+
class BahAttnDecoder3(RnnDecoder):
|
495 |
+
|
496 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
497 |
+
dropout, d_model, **kwargs):
|
498 |
+
"""
|
499 |
+
concatenate fc, attn, word to feed to the rnn
|
500 |
+
"""
|
501 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
502 |
+
dropout, d_model, **kwargs)
|
503 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
504 |
+
self.model = getattr(nn, self.rnn_type)(
|
505 |
+
input_size=self.emb_dim + attn_emb_dim,
|
506 |
+
hidden_size=self.d_model,
|
507 |
+
batch_first=True,
|
508 |
+
num_layers=self.num_layers,
|
509 |
+
bidirectional=self.bidirectional)
|
510 |
+
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
511 |
+
self.d_model * (self.bidirectional + 1) * \
|
512 |
+
self.num_layers,
|
513 |
+
attn_size)
|
514 |
+
self.ctx_proj = lambda x: x
|
515 |
+
self.apply(init)
|
516 |
+
|
517 |
+
def forward(self, input_dict):
|
518 |
+
word = input_dict["word"]
|
519 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
520 |
+
fc_emb = input_dict["fc_emb"]
|
521 |
+
attn_emb = input_dict["attn_emb"]
|
522 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
523 |
+
|
524 |
+
if word.size(-1) == self.fc_emb_dim: # fc_emb
|
525 |
+
embed = word.unsqueeze(1)
|
526 |
+
elif word.size(-1) == 1: # word
|
527 |
+
word = word.to(fc_emb.device)
|
528 |
+
embed = self.in_dropout(self.word_embedding(word))
|
529 |
+
else:
|
530 |
+
raise Exception(f"problem with word input size {word.size()}")
|
531 |
+
|
532 |
+
# embed: [N, 1, embed_size]
|
533 |
+
if state is None:
|
534 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
535 |
+
if self.rnn_type == "LSTM":
|
536 |
+
query = state[0].transpose(0, 1).flatten(1)
|
537 |
+
else:
|
538 |
+
query = state.transpose(0, 1).flatten(1)
|
539 |
+
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
540 |
+
|
541 |
+
p_ctx = self.ctx_proj(c)
|
542 |
+
rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1)
|
543 |
+
|
544 |
+
out, state = self.model(rnn_input, state)
|
545 |
+
|
546 |
+
output = {
|
547 |
+
"state": state,
|
548 |
+
"embed": out,
|
549 |
+
"logit": self.classifier(out),
|
550 |
+
"attn_weight": attn_weight
|
551 |
+
}
|
552 |
+
return output
|
553 |
+
|
554 |
+
|
555 |
+
class SpecificityBahAttnDecoder(RnnDecoder):
|
556 |
+
|
557 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
558 |
+
dropout, d_model, **kwargs):
|
559 |
+
"""
|
560 |
+
concatenate fc, attn, word to feed to the rnn
|
561 |
+
"""
|
562 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
563 |
+
dropout, d_model, **kwargs)
|
564 |
+
attn_size = kwargs.get("attn_size", self.d_model)
|
565 |
+
self.model = getattr(nn, self.rnn_type)(
|
566 |
+
input_size=self.emb_dim + attn_emb_dim + 1,
|
567 |
+
hidden_size=self.d_model,
|
568 |
+
batch_first=True,
|
569 |
+
num_layers=self.num_layers,
|
570 |
+
bidirectional=self.bidirectional)
|
571 |
+
self.attn = Seq2SeqAttention(self.attn_emb_dim,
|
572 |
+
self.d_model * (self.bidirectional + 1) * \
|
573 |
+
self.num_layers,
|
574 |
+
attn_size)
|
575 |
+
self.ctx_proj = lambda x: x
|
576 |
+
self.apply(init)
|
577 |
+
|
578 |
+
def forward(self, input_dict):
|
579 |
+
word = input_dict["word"]
|
580 |
+
state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
|
581 |
+
fc_emb = input_dict["fc_emb"]
|
582 |
+
attn_emb = input_dict["attn_emb"]
|
583 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
584 |
+
condition = input_dict["condition"] # [N,]
|
585 |
+
|
586 |
+
word = word.to(fc_emb.device)
|
587 |
+
embed = self.in_dropout(self.word_embedding(word))
|
588 |
+
|
589 |
+
# embed: [N, 1, embed_size]
|
590 |
+
if state is None:
|
591 |
+
state = self.init_hidden(word.size(0), fc_emb.device)
|
592 |
+
if self.rnn_type == "LSTM":
|
593 |
+
query = state[0].transpose(0, 1).flatten(1)
|
594 |
+
else:
|
595 |
+
query = state.transpose(0, 1).flatten(1)
|
596 |
+
c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
|
597 |
+
|
598 |
+
p_ctx = self.ctx_proj(c)
|
599 |
+
rnn_input = torch.cat(
|
600 |
+
(embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)),
|
601 |
+
dim=-1)
|
602 |
+
|
603 |
+
out, state = self.model(rnn_input, state)
|
604 |
+
|
605 |
+
output = {
|
606 |
+
"state": state,
|
607 |
+
"embed": out,
|
608 |
+
"logit": self.classifier(out),
|
609 |
+
"attn_weight": attn_weight
|
610 |
+
}
|
611 |
+
return output
|
612 |
+
|
613 |
+
|
614 |
+
class TransformerDecoder(BaseDecoder):
|
615 |
+
|
616 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, **kwargs):
|
617 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
618 |
+
dropout=dropout,)
|
619 |
+
self.d_model = emb_dim
|
620 |
+
self.nhead = kwargs.get("nhead", self.d_model // 64)
|
621 |
+
self.nlayers = kwargs.get("nlayers", 2)
|
622 |
+
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
|
623 |
+
|
624 |
+
self.pos_encoder = PositionalEncoding(self.d_model, dropout)
|
625 |
+
layer = nn.TransformerDecoderLayer(d_model=self.d_model,
|
626 |
+
nhead=self.nhead,
|
627 |
+
dim_feedforward=self.dim_feedforward,
|
628 |
+
dropout=dropout)
|
629 |
+
self.model = nn.TransformerDecoder(layer, self.nlayers)
|
630 |
+
self.classifier = nn.Linear(self.d_model, vocab_size)
|
631 |
+
self.attn_proj = nn.Sequential(
|
632 |
+
nn.Linear(self.attn_emb_dim, self.d_model),
|
633 |
+
nn.ReLU(),
|
634 |
+
nn.Dropout(dropout),
|
635 |
+
nn.LayerNorm(self.d_model)
|
636 |
+
)
|
637 |
+
# self.attn_proj = lambda x: x
|
638 |
+
self.init_params()
|
639 |
+
|
640 |
+
def init_params(self):
|
641 |
+
for p in self.parameters():
|
642 |
+
if p.dim() > 1:
|
643 |
+
nn.init.xavier_uniform_(p)
|
644 |
+
|
645 |
+
def generate_square_subsequent_mask(self, max_length):
|
646 |
+
mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
|
647 |
+
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
648 |
+
return mask
|
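As a quick standalone check, `generate_square_subsequent_mask` produces the standard additive causal mask used by `nn.TransformerDecoder` (0.0 where attention is allowed, `-inf` where it is blocked). For `max_length=3`:

```python
import torch

max_length = 3
mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```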
649 |
+
|
650 |
+
def forward(self, input_dict):
|
651 |
+
word = input_dict["word"]
|
652 |
+
attn_emb = input_dict["attn_emb"]
|
653 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
654 |
+
cap_padding_mask = input_dict["cap_padding_mask"]
|
655 |
+
|
656 |
+
p_attn_emb = self.attn_proj(attn_emb)
|
657 |
+
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
|
658 |
+
word = word.to(attn_emb.device)
|
659 |
+
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
|
660 |
+
embed = embed.transpose(0, 1) # [T, N, emb_dim]
|
661 |
+
embed = self.pos_encoder(embed)
|
662 |
+
|
663 |
+
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
|
664 |
+
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
|
665 |
+
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
|
666 |
+
tgt_key_padding_mask=cap_padding_mask,
|
667 |
+
memory_key_padding_mask=memory_key_padding_mask)
|
668 |
+
output = output.transpose(0, 1)
|
669 |
+
output = {
|
670 |
+
"embed": output,
|
671 |
+
"logit": self.classifier(output),
|
672 |
+
}
|
673 |
+
return output
|
674 |
+
|
675 |
+
|
676 |
+
|
677 |
+
|
678 |
+
class EventTransformerDecoder(TransformerDecoder):
|
679 |
+
|
680 |
+
def forward(self, input_dict):
|
681 |
+
word = input_dict["word"] # index of word embeddings
|
682 |
+
attn_emb = input_dict["attn_emb"]
|
683 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
684 |
+
cap_padding_mask = input_dict["cap_padding_mask"]
|
685 |
+
event_emb = input_dict["event"] # [N, emb_dim]
|
686 |
+
|
687 |
+
p_attn_emb = self.attn_proj(attn_emb)
|
688 |
+
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
|
689 |
+
word = word.to(attn_emb.device)
|
690 |
+
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
|
691 |
+
|
692 |
+
embed = embed.transpose(0, 1) # [T, N, emb_dim]
|
693 |
+
embed += event_emb
|
694 |
+
embed = self.pos_encoder(embed)
|
695 |
+
|
696 |
+
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
|
697 |
+
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
|
698 |
+
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
|
699 |
+
tgt_key_padding_mask=cap_padding_mask,
|
700 |
+
memory_key_padding_mask=memory_key_padding_mask)
|
701 |
+
output = output.transpose(0, 1)
|
702 |
+
output = {
|
703 |
+
"embed": output,
|
704 |
+
"logit": self.classifier(output),
|
705 |
+
}
|
706 |
+
return output
|
707 |
+
|
708 |
+
|
709 |
+
class KeywordProbTransformerDecoder(TransformerDecoder):
|
710 |
+
|
711 |
+
def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
712 |
+
dropout, keyword_classes_num, **kwargs):
|
713 |
+
super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
|
714 |
+
dropout, **kwargs)
|
715 |
+
self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
|
716 |
+
self.word_keyword_norm = nn.LayerNorm(self.d_model)
|
717 |
+
|
718 |
+
def forward(self, input_dict):
|
719 |
+
word = input_dict["word"] # index of word embeddings
|
720 |
+
attn_emb = input_dict["attn_emb"]
|
721 |
+
attn_emb_len = input_dict["attn_emb_len"]
|
722 |
+
cap_padding_mask = input_dict["cap_padding_mask"]
|
723 |
+
keyword = input_dict["keyword"] # [N, keyword_classes_num]
|
724 |
+
|
725 |
+
p_attn_emb = self.attn_proj(attn_emb)
|
726 |
+
p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
|
727 |
+
word = word.to(attn_emb.device)
|
728 |
+
embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
|
729 |
+
|
730 |
+
embed = embed.transpose(0, 1) # [T, N, emb_dim]
|
731 |
+
embed += self.keyword_proj(keyword)
|
732 |
+
embed = self.word_keyword_norm(embed)
|
733 |
+
|
734 |
+
embed = self.pos_encoder(embed)
|
735 |
+
|
736 |
+
tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
|
737 |
+
memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
|
738 |
+
output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
|
739 |
+
tgt_key_padding_mask=cap_padding_mask,
|
740 |
+
memory_key_padding_mask=memory_key_padding_mask)
|
741 |
+
output = output.transpose(0, 1)
|
742 |
+
output = {
|
743 |
+
"embed": output,
|
744 |
+
"logit": self.classifier(output),
|
745 |
+
}
|
746 |
+
return output
|
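A minimal usage sketch for the `TransformerDecoder` defined above. It assumes the repository root is on `PYTHONPATH`; the vocabulary size, embedding sizes and batch contents are illustrative only, and the weights are untrained:

```python
import torch
from audio_to_text.captioning.models.decoder import TransformerDecoder

vocab_size, emb_dim, attn_emb_dim = 5000, 256, 2048   # illustrative sizes
decoder = TransformerDecoder(emb_dim=emb_dim, vocab_size=vocab_size,
                             fc_emb_dim=2048, attn_emb_dim=attn_emb_dim,
                             dropout=0.2)
decoder.eval()

N, T_tgt, T_src = 2, 5, 31
word = torch.randint(0, vocab_size, (N, T_tgt))   # already-generated caption tokens
attn_emb = torch.randn(N, T_src, attn_emb_dim)    # encoder "attn_emb" output
attn_emb_len = torch.tensor([31, 25])             # valid frames per clip
cap_padding_mask = torch.zeros(N, T_tgt, dtype=torch.bool)

with torch.no_grad():
    out = decoder({"word": word,
                   "attn_emb": attn_emb,
                   "attn_emb_len": attn_emb_len,
                   "cap_padding_mask": cap_padding_mask})
print(out["logit"].shape)  # torch.Size([2, 5, 5000])
```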
audio_to_text/captioning/models/encoder.py
ADDED
@@ -0,0 +1,686 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torchaudio import transforms
|
10 |
+
from torchlibrosa.augmentation import SpecAugmentation
|
11 |
+
|
12 |
+
from .utils import mean_with_lens, max_with_lens, \
|
13 |
+
init, pack_wrapper, generate_length_mask, PositionalEncoding
|
14 |
+
|
15 |
+
|
16 |
+
def init_layer(layer):
|
17 |
+
"""Initialize a Linear or Convolutional layer. """
|
18 |
+
nn.init.xavier_uniform_(layer.weight)
|
19 |
+
|
20 |
+
if hasattr(layer, 'bias'):
|
21 |
+
if layer.bias is not None:
|
22 |
+
layer.bias.data.fill_(0.)
|
23 |
+
|
24 |
+
|
25 |
+
def init_bn(bn):
|
26 |
+
"""Initialize a Batchnorm layer. """
|
27 |
+
bn.bias.data.fill_(0.)
|
28 |
+
bn.weight.data.fill_(1.)
|
29 |
+
|
30 |
+
|
31 |
+
class BaseEncoder(nn.Module):
|
32 |
+
|
33 |
+
"""
|
34 |
+
Encode the given audio into embedding
|
35 |
+
Base encoder class, cannot be called directly
|
36 |
+
All encoders should inherit from this class
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
|
40 |
+
super(BaseEncoder, self).__init__()
|
41 |
+
self.spec_dim = spec_dim
|
42 |
+
self.fc_feat_dim = fc_feat_dim
|
43 |
+
self.attn_feat_dim = attn_feat_dim
|
44 |
+
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
#########################
|
48 |
+
# an encoder first encodes audio feature into embedding, obtaining
|
49 |
+
# `encoded`: {
|
50 |
+
# fc_embs: [N, fc_emb_dim],
|
51 |
+
# attn_embs: [N, attn_max_len, attn_emb_dim],
|
52 |
+
# attn_emb_lens: [N,]
|
53 |
+
# }
|
54 |
+
#########################
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
|
58 |
+
class Block2D(nn.Module):
|
59 |
+
|
60 |
+
def __init__(self, cin, cout, kernel_size=3, padding=1):
|
61 |
+
super().__init__()
|
62 |
+
self.block = nn.Sequential(
|
63 |
+
nn.BatchNorm2d(cin),
|
64 |
+
nn.Conv2d(cin,
|
65 |
+
cout,
|
66 |
+
kernel_size=kernel_size,
|
67 |
+
padding=padding,
|
68 |
+
bias=False),
|
69 |
+
nn.LeakyReLU(inplace=True, negative_slope=0.1))
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
return self.block(x)
|
73 |
+
|
74 |
+
|
75 |
+
class LinearSoftPool(nn.Module):
|
76 |
+
"""LinearSoftPool
|
77 |
+
Linear softmax: takes logits and returns a probability close to the actual maximum value.
|
78 |
+
Taken from the paper:
|
79 |
+
A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
|
80 |
+
https://arxiv.org/abs/1810.09050
|
81 |
+
"""
|
82 |
+
def __init__(self, pooldim=1):
|
83 |
+
super().__init__()
|
84 |
+
self.pooldim = pooldim
|
85 |
+
|
86 |
+
def forward(self, logits, time_decision):
|
87 |
+
return (time_decision**2).sum(self.pooldim) / time_decision.sum(
|
88 |
+
self.pooldim)
|
89 |
+
|
90 |
+
|
91 |
+
class MeanPool(nn.Module):
|
92 |
+
|
93 |
+
def __init__(self, pooldim=1):
|
94 |
+
super().__init__()
|
95 |
+
self.pooldim = pooldim
|
96 |
+
|
97 |
+
def forward(self, logits, decision):
|
98 |
+
return torch.mean(decision, dim=self.pooldim)
|
99 |
+
|
100 |
+
|
101 |
+
class AttentionPool(nn.Module):
|
102 |
+
"""docstring for AttentionPool"""
|
103 |
+
def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
|
104 |
+
super().__init__()
|
105 |
+
self.inputdim = inputdim
|
106 |
+
self.outputdim = outputdim
|
107 |
+
self.pooldim = pooldim
|
108 |
+
self.transform = nn.Linear(inputdim, outputdim)
|
109 |
+
self.activ = nn.Softmax(dim=self.pooldim)
|
110 |
+
self.eps = 1e-7
|
111 |
+
|
112 |
+
def forward(self, logits, decision):
|
113 |
+
# Input is (B, T, D)
|
114 |
+
# B, T, D
|
115 |
+
w = self.activ(torch.clamp(self.transform(logits), -15, 15))
|
116 |
+
detect = (decision * w).sum(
|
117 |
+
self.pooldim) / (w.sum(self.pooldim) + self.eps)
|
118 |
+
# B, T, D
|
119 |
+
return detect
|
120 |
+
|
121 |
+
|
122 |
+
class MMPool(nn.Module):
|
123 |
+
|
124 |
+
def __init__(self, dims):
|
125 |
+
super().__init__()
|
126 |
+
self.avgpool = nn.AvgPool2d(dims)
|
127 |
+
self.maxpool = nn.MaxPool2d(dims)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
return self.avgpool(x) + self.maxpool(x)
|
131 |
+
|
132 |
+
|
133 |
+
def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
|
134 |
+
"""parse_poolingfunction
|
135 |
+
A helper function to parse any temporal pooling
|
136 |
+
Pooling is done on dimension 1
|
137 |
+
:param poolingfunction_name:
|
138 |
+
:param **kwargs:
|
139 |
+
"""
|
140 |
+
poolingfunction_name = poolingfunction_name.lower()
|
141 |
+
if poolingfunction_name == 'mean':
|
142 |
+
return MeanPool(pooldim=1)
|
143 |
+
elif poolingfunction_name == 'linear':
|
144 |
+
return LinearSoftPool(pooldim=1)
|
145 |
+
elif poolingfunction_name == 'attention':
|
146 |
+
return AttentionPool(inputdim=kwargs['inputdim'],
|
147 |
+
outputdim=kwargs['outputdim'])
|
148 |
+
|
149 |
+
|
150 |
+
def embedding_pooling(x, lens, pooling="mean"):
|
151 |
+
if pooling == "max":
|
152 |
+
fc_embs = max_with_lens(x, lens)
|
153 |
+
elif pooling == "mean":
|
154 |
+
fc_embs = mean_with_lens(x, lens)
|
155 |
+
elif pooling == "mean+max":
|
156 |
+
x_mean = mean_with_lens(x, lens)
|
157 |
+
x_max = max_with_lens(x, lens)
|
158 |
+
fc_embs = x_mean + x_max
|
159 |
+
elif pooling == "last":
|
160 |
+
indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
|
161 |
+
# indices: [N, 1, hidden]
|
162 |
+
fc_embs = torch.gather(x, 1, indices).squeeze(1)
|
163 |
+
else:
|
164 |
+
raise Exception(f"pooling method {pooling} not supported")
|
165 |
+
return fc_embs
|
166 |
+
|
167 |
+
|
168 |
+
class Cdur5Encoder(BaseEncoder):
|
169 |
+
|
170 |
+
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
|
171 |
+
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
172 |
+
self.pooling = pooling
|
173 |
+
self.features = nn.Sequential(
|
174 |
+
Block2D(1, 32),
|
175 |
+
nn.LPPool2d(4, (2, 4)),
|
176 |
+
Block2D(32, 128),
|
177 |
+
Block2D(128, 128),
|
178 |
+
nn.LPPool2d(4, (2, 4)),
|
179 |
+
Block2D(128, 128),
|
180 |
+
Block2D(128, 128),
|
181 |
+
nn.LPPool2d(4, (1, 4)),
|
182 |
+
nn.Dropout(0.3),
|
183 |
+
)
|
184 |
+
with torch.no_grad():
|
185 |
+
rnn_input_dim = self.features(
|
186 |
+
torch.randn(1, 1, 500, spec_dim)).shape
|
187 |
+
rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
|
188 |
+
|
189 |
+
self.gru = nn.GRU(rnn_input_dim,
|
190 |
+
128,
|
191 |
+
bidirectional=True,
|
192 |
+
batch_first=True)
|
193 |
+
self.apply(init)
|
194 |
+
|
195 |
+
def forward(self, input_dict):
|
196 |
+
x = input_dict["spec"]
|
197 |
+
lens = input_dict["spec_len"]
|
198 |
+
if "upsample" not in input_dict:
|
199 |
+
input_dict["upsample"] = False
|
200 |
+
lens = torch.as_tensor(copy.deepcopy(lens))
|
201 |
+
N, T, _ = x.shape
|
202 |
+
x = x.unsqueeze(1)
|
203 |
+
x = self.features(x)
|
204 |
+
x = x.transpose(1, 2).contiguous().flatten(-2)
|
205 |
+
x, _ = self.gru(x)
|
206 |
+
if input_dict["upsample"]:
|
207 |
+
x = nn.functional.interpolate(
|
208 |
+
x.transpose(1, 2),
|
209 |
+
T,
|
210 |
+
mode='linear',
|
211 |
+
align_corners=False).transpose(1, 2)
|
212 |
+
else:
|
213 |
+
lens //= 4
|
214 |
+
attn_emb = x
|
215 |
+
fc_emb = embedding_pooling(x, lens, self.pooling)
|
216 |
+
return {
|
217 |
+
"attn_emb": attn_emb,
|
218 |
+
"fc_emb": fc_emb,
|
219 |
+
"attn_emb_len": lens
|
220 |
+
}
|
221 |
+
|
222 |
+
|
223 |
+
def conv_conv_block(in_channel, out_channel):
|
224 |
+
return nn.Sequential(
|
225 |
+
nn.Conv2d(in_channel,
|
226 |
+
out_channel,
|
227 |
+
kernel_size=3,
|
228 |
+
bias=False,
|
229 |
+
padding=1),
|
230 |
+
nn.BatchNorm2d(out_channel),
|
231 |
+
nn.ReLU(True),
|
232 |
+
nn.Conv2d(out_channel,
|
233 |
+
out_channel,
|
234 |
+
kernel_size=3,
|
235 |
+
bias=False,
|
236 |
+
padding=1),
|
237 |
+
nn.BatchNorm2d(out_channel),
|
238 |
+
nn.ReLU(True)
|
239 |
+
)
|
240 |
+
|
241 |
+
|
242 |
+
class Cdur8Encoder(BaseEncoder):
|
243 |
+
|
244 |
+
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
|
245 |
+
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
246 |
+
self.pooling = pooling
|
247 |
+
self.features = nn.Sequential(
|
248 |
+
conv_conv_block(1, 64),
|
249 |
+
MMPool((2, 2)),
|
250 |
+
nn.Dropout(0.2, True),
|
251 |
+
conv_conv_block(64, 128),
|
252 |
+
MMPool((2, 2)),
|
253 |
+
nn.Dropout(0.2, True),
|
254 |
+
conv_conv_block(128, 256),
|
255 |
+
MMPool((1, 2)),
|
256 |
+
nn.Dropout(0.2, True),
|
257 |
+
conv_conv_block(256, 512),
|
258 |
+
MMPool((1, 2)),
|
259 |
+
nn.Dropout(0.2, True),
|
260 |
+
nn.AdaptiveAvgPool2d((None, 1)),
|
261 |
+
)
|
262 |
+
self.init_bn = nn.BatchNorm2d(spec_dim)
|
263 |
+
self.embedding = nn.Linear(512, 512)
|
264 |
+
self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
|
265 |
+
self.apply(init)
|
266 |
+
|
267 |
+
def forward(self, input_dict):
|
268 |
+
x = input_dict["spec"]
|
269 |
+
lens = input_dict["spec_len"]
|
270 |
+
lens = torch.as_tensor(copy.deepcopy(lens))
|
271 |
+
x = x.unsqueeze(1) # B x 1 x T x D
|
272 |
+
x = x.transpose(1, 3)
|
273 |
+
x = self.init_bn(x)
|
274 |
+
x = x.transpose(1, 3)
|
275 |
+
x = self.features(x)
|
276 |
+
x = x.transpose(1, 2).contiguous().flatten(-2)
|
277 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
278 |
+
x = F.relu_(self.embedding(x))
|
279 |
+
x, _ = self.gru(x)
|
280 |
+
attn_emb = x
|
281 |
+
lens //= 4
|
282 |
+
fc_emb = embedding_pooling(x, lens, self.pooling)
|
283 |
+
return {
|
284 |
+
"attn_emb": attn_emb,
|
285 |
+
"fc_emb": fc_emb,
|
286 |
+
"attn_emb_len": lens
|
287 |
+
}
|
288 |
+
|
289 |
+
|
290 |
+
class Cnn10Encoder(BaseEncoder):
|
291 |
+
|
292 |
+
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
|
293 |
+
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
294 |
+
self.features = nn.Sequential(
|
295 |
+
conv_conv_block(1, 64),
|
296 |
+
nn.AvgPool2d((2, 2)),
|
297 |
+
nn.Dropout(0.2, True),
|
298 |
+
conv_conv_block(64, 128),
|
299 |
+
nn.AvgPool2d((2, 2)),
|
300 |
+
nn.Dropout(0.2, True),
|
301 |
+
conv_conv_block(128, 256),
|
302 |
+
nn.AvgPool2d((2, 2)),
|
303 |
+
nn.Dropout(0.2, True),
|
304 |
+
conv_conv_block(256, 512),
|
305 |
+
nn.AvgPool2d((2, 2)),
|
306 |
+
nn.Dropout(0.2, True),
|
307 |
+
nn.AdaptiveAvgPool2d((None, 1)),
|
308 |
+
)
|
309 |
+
self.init_bn = nn.BatchNorm2d(spec_dim)
|
310 |
+
self.embedding = nn.Linear(512, 512)
|
311 |
+
self.apply(init)
|
312 |
+
|
313 |
+
def forward(self, input_dict):
|
314 |
+
x = input_dict["spec"]
|
315 |
+
lens = input_dict["spec_len"]
|
316 |
+
lens = torch.as_tensor(copy.deepcopy(lens))
|
317 |
+
x = x.unsqueeze(1) # [N, 1, T, D]
|
318 |
+
x = x.transpose(1, 3)
|
319 |
+
x = self.init_bn(x)
|
320 |
+
x = x.transpose(1, 3)
|
321 |
+
x = self.features(x) # [N, 512, T/16, 1]
|
322 |
+
x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512]
|
323 |
+
attn_emb = x
|
324 |
+
lens //= 16
|
325 |
+
fc_emb = embedding_pooling(x, lens, "mean+max")
|
326 |
+
fc_emb = F.dropout(fc_emb, p=0.5, training=self.training)
|
327 |
+
fc_emb = self.embedding(fc_emb)
|
328 |
+
fc_emb = F.relu_(fc_emb)
|
329 |
+
return {
|
330 |
+
"attn_emb": attn_emb,
|
331 |
+
"fc_emb": fc_emb,
|
332 |
+
"attn_emb_len": lens
|
333 |
+
}
|
334 |
+
|
335 |
+
|
336 |
+
class ConvBlock(nn.Module):
|
337 |
+
def __init__(self, in_channels, out_channels):
|
338 |
+
|
339 |
+
super(ConvBlock, self).__init__()
|
340 |
+
|
341 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
342 |
+
out_channels=out_channels,
|
343 |
+
kernel_size=(3, 3), stride=(1, 1),
|
344 |
+
padding=(1, 1), bias=False)
|
345 |
+
|
346 |
+
self.conv2 = nn.Conv2d(in_channels=out_channels,
|
347 |
+
out_channels=out_channels,
|
348 |
+
kernel_size=(3, 3), stride=(1, 1),
|
349 |
+
padding=(1, 1), bias=False)
|
350 |
+
|
351 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
352 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
353 |
+
|
354 |
+
self.init_weight()
|
355 |
+
|
356 |
+
def init_weight(self):
|
357 |
+
init_layer(self.conv1)
|
358 |
+
init_layer(self.conv2)
|
359 |
+
init_bn(self.bn1)
|
360 |
+
init_bn(self.bn2)
|
361 |
+
|
362 |
+
|
363 |
+
def forward(self, input, pool_size=(2, 2), pool_type='avg'):
|
364 |
+
|
365 |
+
x = input
|
366 |
+
x = F.relu_(self.bn1(self.conv1(x)))
|
367 |
+
x = F.relu_(self.bn2(self.conv2(x)))
|
368 |
+
if pool_type == 'max':
|
369 |
+
x = F.max_pool2d(x, kernel_size=pool_size)
|
370 |
+
elif pool_type == 'avg':
|
371 |
+
x = F.avg_pool2d(x, kernel_size=pool_size)
|
372 |
+
elif pool_type == 'avg+max':
|
373 |
+
x1 = F.avg_pool2d(x, kernel_size=pool_size)
|
374 |
+
x2 = F.max_pool2d(x, kernel_size=pool_size)
|
375 |
+
x = x1 + x2
|
376 |
+
else:
|
377 |
+
raise Exception('Incorrect argument!')
|
378 |
+
|
379 |
+
return x
|
380 |
+
|
381 |
+
|
382 |
+
class Cnn14Encoder(nn.Module):
|
383 |
+
def __init__(self, sample_rate=32000):
|
384 |
+
super().__init__()
|
385 |
+
sr_to_fmax = {
|
386 |
+
32000: 14000,
|
387 |
+
16000: 8000
|
388 |
+
}
|
389 |
+
# Logmel spectrogram extractor
|
390 |
+
self.melspec_extractor = transforms.MelSpectrogram(
|
391 |
+
sample_rate=sample_rate,
|
392 |
+
n_fft=32 * sample_rate // 1000,
|
393 |
+
win_length=32 * sample_rate // 1000,
|
394 |
+
hop_length=10 * sample_rate // 1000,
|
395 |
+
f_min=50,
|
396 |
+
f_max=sr_to_fmax[sample_rate],
|
397 |
+
n_mels=64,
|
398 |
+
norm="slaney",
|
399 |
+
mel_scale="slaney"
|
400 |
+
)
|
401 |
+
self.hop_length = 10 * sample_rate // 1000
|
402 |
+
self.db_transform = transforms.AmplitudeToDB()
|
403 |
+
# Spec augmenter
|
404 |
+
self.spec_augmenter = SpecAugmentation(time_drop_width=64,
|
405 |
+
time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2)
|
406 |
+
|
407 |
+
self.bn0 = nn.BatchNorm2d(64)
|
408 |
+
|
409 |
+
self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
|
410 |
+
self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
|
411 |
+
self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
|
412 |
+
self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
|
413 |
+
self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
|
414 |
+
self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
|
415 |
+
|
416 |
+
self.downsample_ratio = 32
|
417 |
+
|
418 |
+
self.fc1 = nn.Linear(2048, 2048, bias=True)
|
419 |
+
|
420 |
+
self.init_weight()
|
421 |
+
|
422 |
+
def init_weight(self):
|
423 |
+
init_bn(self.bn0)
|
424 |
+
init_layer(self.fc1)
|
425 |
+
|
426 |
+
def load_pretrained(self, pretrained):
|
427 |
+
checkpoint = torch.load(pretrained, map_location="cpu")
|
428 |
+
|
429 |
+
if "model" in checkpoint:
|
430 |
+
state_keys = checkpoint["model"].keys()
|
431 |
+
backbone = False
|
432 |
+
for key in state_keys:
|
433 |
+
if key.startswith("backbone."):
|
434 |
+
backbone = True
|
435 |
+
break
|
436 |
+
|
437 |
+
if backbone: # COLA
|
438 |
+
state_dict = {}
|
439 |
+
for key, value in checkpoint["model"].items():
|
440 |
+
if key.startswith("backbone."):
|
441 |
+
model_key = key.replace("backbone.", "")
|
442 |
+
state_dict[model_key] = value
|
443 |
+
else: # PANNs
|
444 |
+
state_dict = checkpoint["model"]
|
445 |
+
elif "state_dict" in checkpoint: # CLAP
|
446 |
+
state_dict = checkpoint["state_dict"]
|
447 |
+
state_dict_keys = list(filter(
|
448 |
+
lambda x: "audio_encoder" in x, state_dict.keys()))
|
449 |
+
state_dict = {
|
450 |
+
key.replace('audio_encoder.', ''): state_dict[key]
|
451 |
+
for key in state_dict_keys
|
452 |
+
}
|
453 |
+
else:
|
454 |
+
raise Exception("Unknown checkpoint format")
|
455 |
+
|
456 |
+
model_dict = self.state_dict()
|
457 |
+
pretrained_dict = {
|
458 |
+
k: v for k, v in state_dict.items() if (k in model_dict) and (
|
459 |
+
model_dict[k].shape == v.shape)
|
460 |
+
}
|
461 |
+
model_dict.update(pretrained_dict)
|
462 |
+
self.load_state_dict(model_dict, strict=True)
|
463 |
+
|
464 |
+
def forward(self, input_dict):
|
465 |
+
"""
|
466 |
+
Input: (batch_size, n_samples)"""
|
467 |
+
waveform = input_dict["wav"]
|
468 |
+
wave_length = input_dict["wav_len"]
|
469 |
+
specaug = input_dict["specaug"]
|
470 |
+
x = self.melspec_extractor(waveform)
|
471 |
+
x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
|
472 |
+
x = x.transpose(1, 2)
|
473 |
+
x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
|
474 |
+
|
475 |
+
# SpecAugment
|
476 |
+
if self.training and specaug:
|
477 |
+
x = self.spec_augmenter(x)
|
478 |
+
|
479 |
+
x = x.transpose(1, 3)
|
480 |
+
x = self.bn0(x)
|
481 |
+
x = x.transpose(1, 3)
|
482 |
+
|
483 |
+
x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
|
484 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
485 |
+
x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
|
486 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
487 |
+
x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
|
488 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
489 |
+
x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
|
490 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
491 |
+
x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
|
492 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
493 |
+
x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
|
494 |
+
x = F.dropout(x, p=0.2, training=self.training)
|
495 |
+
x = torch.mean(x, dim=3)
|
496 |
+
attn_emb = x.transpose(1, 2)
|
497 |
+
|
498 |
+
wave_length = torch.as_tensor(wave_length)
|
499 |
+
feat_length = torch.div(wave_length, self.hop_length,
|
500 |
+
rounding_mode="floor") + 1
|
501 |
+
feat_length = torch.div(feat_length, self.downsample_ratio,
|
502 |
+
rounding_mode="floor")
|
503 |
+
x_max = max_with_lens(attn_emb, feat_length)
|
504 |
+
x_mean = mean_with_lens(attn_emb, feat_length)
|
505 |
+
x = x_max + x_mean
|
506 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
507 |
+
x = F.relu_(self.fc1(x))
|
508 |
+
fc_emb = F.dropout(x, p=0.5, training=self.training)
|
509 |
+
|
510 |
+
output_dict = {
|
511 |
+
'fc_emb': fc_emb,
|
512 |
+
'attn_emb': attn_emb,
|
513 |
+
'attn_emb_len': feat_length
|
514 |
+
}
|
515 |
+
|
516 |
+
return output_dict
|
517 |
+
|
518 |
+
|
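As a concrete check of the length bookkeeping in `Cnn14Encoder.forward`: with the 10 ms hop at 32 kHz (`hop_length = 320`) and `downsample_ratio = 32`, a 10-second clip maps to 31 attention frames. The same arithmetic as a standalone sketch:

```python
import torch

sample_rate = 32000
hop_length = 10 * sample_rate // 1000   # 320 samples, as in Cnn14Encoder
downsample_ratio = 32                   # five conv blocks with (2, 2) pooling

wave_length = torch.tensor([10 * sample_rate])   # one 10-second clip
feat_length = torch.div(wave_length, hop_length, rounding_mode="floor") + 1
feat_length = torch.div(feat_length, downsample_ratio, rounding_mode="floor")
print(feat_length)  # tensor([31]): number of "attn_emb" frames for this clip
```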
519 |
+
class RnnEncoder(BaseEncoder):
|
520 |
+
|
521 |
+
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
|
522 |
+
pooling="mean", **kwargs):
|
523 |
+
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
524 |
+
self.pooling = pooling
|
525 |
+
self.hidden_size = kwargs.get('hidden_size', 512)
|
526 |
+
self.bidirectional = kwargs.get('bidirectional', False)
|
527 |
+
self.num_layers = kwargs.get('num_layers', 1)
|
528 |
+
self.dropout = kwargs.get('dropout', 0.2)
|
529 |
+
self.rnn_type = kwargs.get('rnn_type', "GRU")
|
530 |
+
self.in_bn = kwargs.get('in_bn', False)
|
531 |
+
self.embed_dim = self.hidden_size * (self.bidirectional + 1)
|
532 |
+
self.network = getattr(nn, self.rnn_type)(
|
533 |
+
attn_feat_dim,
|
534 |
+
self.hidden_size,
|
535 |
+
num_layers=self.num_layers,
|
536 |
+
bidirectional=self.bidirectional,
|
537 |
+
dropout=self.dropout,
|
538 |
+
batch_first=True)
|
539 |
+
if self.in_bn:
|
540 |
+
self.bn = nn.BatchNorm1d(self.embed_dim)
|
541 |
+
self.apply(init)
|
542 |
+
|
543 |
+
def forward(self, input_dict):
|
544 |
+
x = input_dict["attn"]
|
545 |
+
lens = input_dict["attn_len"]
|
546 |
+
lens = torch.as_tensor(lens)
|
547 |
+
# x: [N, T, E]
|
548 |
+
if self.in_bn:
|
549 |
+
x = pack_wrapper(self.bn, x, lens)
|
550 |
+
out = pack_wrapper(self.network, x, lens)
|
551 |
+
# out: [N, T, hidden]
|
552 |
+
attn_emb = out
|
553 |
+
fc_emb = embedding_pooling(out, lens, self.pooling)
|
554 |
+
return {
|
555 |
+
"attn_emb": attn_emb,
|
556 |
+
"fc_emb": fc_emb,
|
557 |
+
"attn_emb_len": lens
|
558 |
+
}
|
559 |
+
|
560 |
+
|
561 |
+
class Cnn14RnnEncoder(nn.Module):
|
562 |
+
def __init__(self, sample_rate=32000, pretrained=None,
|
563 |
+
freeze_cnn=False, freeze_cnn_bn=False,
|
564 |
+
pooling="mean", **kwargs):
|
565 |
+
super().__init__()
|
566 |
+
self.cnn = Cnn14Encoder(sample_rate)
|
567 |
+
self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
|
568 |
+
if pretrained is not None:
|
569 |
+
self.cnn.load_pretrained(pretrained)
|
570 |
+
if freeze_cnn:
|
571 |
+
assert pretrained is not None, "cnn is not pretrained but frozen"
|
572 |
+
for param in self.cnn.parameters():
|
573 |
+
param.requires_grad = False
|
574 |
+
self.freeze_cnn_bn = freeze_cnn_bn
|
575 |
+
|
576 |
+
def train(self, mode):
|
577 |
+
super().train(mode=mode)
|
578 |
+
if self.freeze_cnn_bn:
|
579 |
+
def bn_eval(module):
|
580 |
+
class_name = module.__class__.__name__
|
581 |
+
if class_name.find("BatchNorm") != -1:
|
582 |
+
module.eval()
|
583 |
+
self.cnn.apply(bn_eval)
|
584 |
+
return self
|
585 |
+
|
586 |
+
def forward(self, input_dict):
|
587 |
+
output_dict = self.cnn(input_dict)
|
588 |
+
output_dict["attn"] = output_dict["attn_emb"]
|
589 |
+
output_dict["attn_len"] = output_dict["attn_emb_len"]
|
590 |
+
del output_dict["attn_emb"], output_dict["attn_emb_len"]
|
591 |
+
output_dict = self.rnn(output_dict)
|
592 |
+
return output_dict
|
593 |
+
|
594 |
+
|
595 |
+
class TransformerEncoder(BaseEncoder):
|
596 |
+
|
597 |
+
def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
|
598 |
+
super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
|
599 |
+
self.d_model = d_model
|
600 |
+
dropout = kwargs.get("dropout", 0.2)
|
601 |
+
self.nhead = kwargs.get("nhead", self.d_model // 64)
|
602 |
+
self.nlayers = kwargs.get("nlayers", 2)
|
603 |
+
self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
|
604 |
+
|
605 |
+
self.attn_proj = nn.Sequential(
|
606 |
+
nn.Linear(attn_feat_dim, self.d_model),
|
607 |
+
nn.ReLU(),
|
608 |
+
nn.Dropout(dropout),
|
609 |
+
nn.LayerNorm(self.d_model)
|
610 |
+
)
|
611 |
+
layer = nn.TransformerEncoderLayer(d_model=self.d_model,
|
612 |
+
nhead=self.nhead,
|
613 |
+
dim_feedforward=self.dim_feedforward,
|
614 |
+
dropout=dropout)
|
615 |
+
self.model = nn.TransformerEncoder(layer, self.nlayers)
|
616 |
+
self.cls_token = nn.Parameter(torch.zeros(d_model))
|
617 |
+
self.init_params()
|
618 |
+
|
619 |
+
def init_params(self):
|
620 |
+
for p in self.parameters():
|
621 |
+
if p.dim() > 1:
|
622 |
+
nn.init.xavier_uniform_(p)
|
623 |
+
|
624 |
+
def forward(self, input_dict):
|
625 |
+
attn_feat = input_dict["attn"]
|
626 |
+
attn_feat_len = input_dict["attn_len"]
|
627 |
+
attn_feat_len = torch.as_tensor(attn_feat_len)
|
628 |
+
|
629 |
+
attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model]
|
630 |
+
|
631 |
+
cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
|
632 |
+
attn_feat.size(0), 1, 1)
|
633 |
+
attn_feat = torch.cat((cls_emb, attn_feat), dim=1)
|
634 |
+
attn_feat = attn_feat.transpose(0, 1)
|
635 |
+
|
636 |
+
attn_feat_len += 1
|
637 |
+
src_key_padding_mask = ~generate_length_mask(
|
638 |
+
attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
|
639 |
+
output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)
|
640 |
+
|
641 |
+
attn_emb = output.transpose(0, 1)
|
642 |
+
fc_emb = attn_emb[:, 0]
|
643 |
+
return {
|
644 |
+
"attn_emb": attn_emb,
|
645 |
+
"fc_emb": fc_emb,
|
646 |
+
"attn_emb_len": attn_feat_len
|
647 |
+
}
|
648 |
+
|
649 |
+
|
650 |
+
class Cnn14TransformerEncoder(nn.Module):
|
651 |
+
def __init__(self, sample_rate=32000, pretrained=None,
|
652 |
+
freeze_cnn=False, freeze_cnn_bn=False,
|
653 |
+
d_model="mean", **kwargs):
|
654 |
+
super().__init__()
|
655 |
+
self.cnn = Cnn14Encoder(sample_rate)
|
656 |
+
self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
|
657 |
+
if pretrained is not None:
|
658 |
+
self.cnn.load_pretrained(pretrained)
|
659 |
+
if freeze_cnn:
|
660 |
+
assert pretrained is not None, "cnn is not pretrained but frozen"
|
661 |
+
for param in self.cnn.parameters():
|
662 |
+
param.requires_grad = False
|
663 |
+
self.freeze_cnn_bn = freeze_cnn_bn
|
664 |
+
|
665 |
+
def train(self, mode):
|
666 |
+
super().train(mode=mode)
|
667 |
+
if self.freeze_cnn_bn:
|
668 |
+
def bn_eval(module):
|
669 |
+
class_name = module.__class__.__name__
|
670 |
+
if class_name.find("BatchNorm") != -1:
|
671 |
+
module.eval()
|
672 |
+
self.cnn.apply(bn_eval)
|
673 |
+
return self
|
674 |
+
|
675 |
+
def forward(self, input_dict):
|
676 |
+
output_dict = self.cnn(input_dict)
|
677 |
+
output_dict["attn"] = output_dict["attn_emb"]
|
678 |
+
output_dict["attn_len"] = output_dict["attn_emb_len"]
|
679 |
+
del output_dict["attn_emb"], output_dict["attn_emb_len"]
|
680 |
+
output_dict = self.trm(output_dict)
|
681 |
+
return output_dict
|
682 |
+
|
683 |
+
|
684 |
+
|
685 |
+
|
686 |
+
|
audio_to_text/captioning/models/transformer_model.py
ADDED
@@ -0,0 +1,265 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .base_model import CaptionModel
|
7 |
+
from .utils import repeat_tensor
|
8 |
+
import audio_to_text.captioning.models.decoder
|
9 |
+
|
10 |
+
|
11 |
+
class TransformerModel(CaptionModel):
|
12 |
+
|
13 |
+
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
14 |
+
if not hasattr(self, "compatible_decoders"):
|
15 |
+
self.compatible_decoders = (
|
16 |
+
audio_to_text.captioning.models.decoder.TransformerDecoder,
|
17 |
+
)
|
18 |
+
super().__init__(encoder, decoder, **kwargs)
|
19 |
+
|
20 |
+
def seq_forward(self, input_dict):
|
21 |
+
cap = input_dict["cap"]
|
22 |
+
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
23 |
+
cap_padding_mask = cap_padding_mask[:, :-1]
|
24 |
+
output = self.decoder(
|
25 |
+
{
|
26 |
+
"word": cap[:, :-1],
|
27 |
+
"attn_emb": input_dict["attn_emb"],
|
28 |
+
"attn_emb_len": input_dict["attn_emb_len"],
|
29 |
+
"cap_padding_mask": cap_padding_mask
|
30 |
+
}
|
31 |
+
)
|
32 |
+
return output
|
33 |
+
|
34 |
+
def prepare_decoder_input(self, input_dict, output):
|
35 |
+
decoder_input = {
|
36 |
+
"attn_emb": input_dict["attn_emb"],
|
37 |
+
"attn_emb_len": input_dict["attn_emb_len"]
|
38 |
+
}
|
39 |
+
t = input_dict["t"]
|
40 |
+
|
41 |
+
###############
|
42 |
+
# determine input word
|
43 |
+
################
|
44 |
+
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
|
45 |
+
word = input_dict["cap"][:, :t+1]
|
46 |
+
else:
|
47 |
+
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
|
48 |
+
if t == 0:
|
49 |
+
word = start_word
|
50 |
+
else:
|
51 |
+
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
|
52 |
+
# word: [N, T]
|
53 |
+
decoder_input["word"] = word
|
54 |
+
|
55 |
+
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
|
56 |
+
decoder_input["cap_padding_mask"] = cap_padding_mask
|
57 |
+
return decoder_input
|
58 |
+
|
59 |
+
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
60 |
+
decoder_input = {}
|
61 |
+
t = input_dict["t"]
|
62 |
+
i = input_dict["sample_idx"]
|
63 |
+
beam_size = input_dict["beam_size"]
|
64 |
+
###############
|
65 |
+
# prepare attn embeds
|
66 |
+
################
|
67 |
+
if t == 0:
|
68 |
+
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
|
69 |
+
attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
|
70 |
+
output_i["attn_emb"] = attn_emb
|
71 |
+
output_i["attn_emb_len"] = attn_emb_len
|
72 |
+
decoder_input["attn_emb"] = output_i["attn_emb"]
|
73 |
+
decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
|
74 |
+
###############
|
75 |
+
# determine input word
|
76 |
+
################
|
77 |
+
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
|
78 |
+
if t == 0:
|
79 |
+
word = start_word
|
80 |
+
else:
|
81 |
+
word = torch.cat((start_word, output_i["seq"]), dim=-1)
|
82 |
+
decoder_input["word"] = word
|
83 |
+
cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
|
84 |
+
decoder_input["cap_padding_mask"] = cap_padding_mask
|
85 |
+
|
86 |
+
return decoder_input
|
87 |
+
|
88 |
+
|
89 |
+
class M2TransformerModel(CaptionModel):
|
90 |
+
|
91 |
+
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
92 |
+
if not hasattr(self, "compatible_decoders"):
|
93 |
+
self.compatible_decoders = (
|
94 |
+
audio_to_text.captioning.models.decoder.M2TransformerDecoder,
|
95 |
+
)
|
96 |
+
super().__init__(encoder, decoder, **kwargs)
|
97 |
+
self.check_encoder_compatibility()
|
98 |
+
|
99 |
+
def check_encoder_compatibility(self):
|
100 |
+
assert isinstance(self.encoder, audio_to_text.captioning.models.encoder.M2TransformerEncoder), \
|
101 |
+
f"only M2TransformerModel is compatible with {self.__class__.__name__}"
|
102 |
+
|
103 |
+
|
104 |
+
def seq_forward(self, input_dict):
|
105 |
+
cap = input_dict["cap"]
|
106 |
+
output = self.decoder(
|
107 |
+
{
|
108 |
+
"word": cap[:, :-1],
|
109 |
+
"attn_emb": input_dict["attn_emb"],
|
110 |
+
"attn_emb_mask": input_dict["attn_emb_mask"],
|
111 |
+
}
|
112 |
+
)
|
113 |
+
return output
|
114 |
+
|
115 |
+
def prepare_decoder_input(self, input_dict, output):
|
116 |
+
decoder_input = {
|
117 |
+
"attn_emb": input_dict["attn_emb"],
|
118 |
+
"attn_emb_mask": input_dict["attn_emb_mask"]
|
119 |
+
}
|
120 |
+
t = input_dict["t"]
|
121 |
+
|
122 |
+
###############
|
123 |
+
# determine input word
|
124 |
+
################
|
125 |
+
if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
|
126 |
+
word = input_dict["cap"][:, :t+1]
|
127 |
+
else:
|
128 |
+
start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
|
129 |
+
if t == 0:
|
130 |
+
word = start_word
|
131 |
+
else:
|
132 |
+
word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
|
133 |
+
# word: [N, T]
|
134 |
+
decoder_input["word"] = word
|
135 |
+
|
136 |
+
return decoder_input
|
137 |
+
|
138 |
+
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
139 |
+
decoder_input = {}
|
140 |
+
t = input_dict["t"]
|
141 |
+
i = input_dict["sample_idx"]
|
142 |
+
beam_size = input_dict["beam_size"]
|
143 |
+
###############
|
144 |
+
# prepare attn embeds
|
145 |
+
################
|
146 |
+
if t == 0:
|
147 |
+
attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
|
148 |
+
attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
|
149 |
+
output_i["attn_emb"] = attn_emb
|
150 |
+
output_i["attn_emb_mask"] = attn_emb_mask
|
151 |
+
decoder_input["attn_emb"] = output_i["attn_emb"]
|
152 |
+
decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
|
153 |
+
###############
|
154 |
+
# determine input word
|
155 |
+
################
|
156 |
+
start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
|
157 |
+
if t == 0:
|
158 |
+
word = start_word
|
159 |
+
else:
|
160 |
+
word = torch.cat((start_word, output_i["seq"]), dim=-1)
|
161 |
+
decoder_input["word"] = word
|
162 |
+
|
163 |
+
return decoder_input
|
164 |
+
|
165 |
+
|
166 |
+
class EventEncoder(nn.Module):
|
167 |
+
"""
|
168 |
+
Encode the Label information in AudioCaps and AudioSet
|
169 |
+
"""
|
170 |
+
def __init__(self, emb_dim, vocab_size=527):
|
171 |
+
super(EventEncoder, self).__init__()
|
172 |
+
self.label_embedding = nn.Parameter(
|
173 |
+
torch.randn((vocab_size, emb_dim)), requires_grad=True)
|
174 |
+
|
175 |
+
def forward(self, word_idxs):
|
176 |
+
indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
|
177 |
+
embeddings = indices @ self.label_embedding
|
178 |
+
return embeddings
|
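`EventEncoder` normalizes a multi-hot AudioSet label vector into a distribution and uses it to mix the rows of a learnable label-embedding matrix. A small standalone sketch with illustrative sizes:

```python
import torch
from audio_to_text.captioning.models.transformer_model import EventEncoder

enc = EventEncoder(emb_dim=256, vocab_size=527)
events = torch.zeros(2, 527)
events[0, [0, 72]] = 1.0   # clip 0: two active AudioSet classes
events[1, 137] = 1.0       # clip 1: one active class
emb = enc(events)          # weighted average of the corresponding label embeddings
print(emb.shape)           # torch.Size([2, 256])
```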
179 |
+
|
180 |
+
|
181 |
+
class EventCondTransformerModel(TransformerModel):
|
182 |
+
|
183 |
+
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
184 |
+
if not hasattr(self, "compatible_decoders"):
|
185 |
+
self.compatible_decoders = (
|
186 |
+
audio_to_text.captioning.models.decoder.EventTransformerDecoder,
|
187 |
+
)
|
188 |
+
super().__init__(encoder, decoder, **kwargs)
|
189 |
+
self.label_encoder = EventEncoder(decoder.emb_dim, 527)
|
190 |
+
self.train_forward_keys += ["events"]
|
191 |
+
self.inference_forward_keys += ["events"]
|
192 |
+
|
193 |
+
# def seq_forward(self, input_dict):
|
194 |
+
# cap = input_dict["cap"]
|
195 |
+
# cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
196 |
+
# cap_padding_mask = cap_padding_mask[:, :-1]
|
197 |
+
# output = self.decoder(
|
198 |
+
# {
|
199 |
+
# "word": cap[:, :-1],
|
200 |
+
# "attn_emb": input_dict["attn_emb"],
|
201 |
+
# "attn_emb_len": input_dict["attn_emb_len"],
|
202 |
+
# "cap_padding_mask": cap_padding_mask
|
203 |
+
# }
|
204 |
+
# )
|
205 |
+
# return output
|
206 |
+
|
207 |
+
def prepare_decoder_input(self, input_dict, output):
|
208 |
+
decoder_input = super().prepare_decoder_input(input_dict, output)
|
209 |
+
decoder_input["events"] = self.label_encoder(input_dict["events"])
|
210 |
+
return decoder_input
|
211 |
+
|
212 |
+
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
213 |
+
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
|
214 |
+
t = input_dict["t"]
|
215 |
+
i = input_dict["sample_idx"]
|
216 |
+
beam_size = input_dict["beam_size"]
|
217 |
+
if t == 0:
|
218 |
+
output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
|
219 |
+
decoder_input["events"] = output_i["events"]
|
220 |
+
return decoder_input
|
221 |
+
|
222 |
+
|
223 |
+
class KeywordCondTransformerModel(TransformerModel):
|
224 |
+
|
225 |
+
def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
|
226 |
+
if not hasattr(self, "compatible_decoders"):
|
227 |
+
self.compatible_decoders = (
|
228 |
+
audio_to_text.captioning.models.decoder.KeywordProbTransformerDecoder,
|
229 |
+
)
|
230 |
+
super().__init__(encoder, decoder, **kwargs)
|
231 |
+
self.train_forward_keys += ["keyword"]
|
232 |
+
self.inference_forward_keys += ["keyword"]
|
233 |
+
|
234 |
+
def seq_forward(self, input_dict):
|
235 |
+
cap = input_dict["cap"]
|
236 |
+
cap_padding_mask = (cap == self.pad_idx).to(cap.device)
|
237 |
+
cap_padding_mask = cap_padding_mask[:, :-1]
|
238 |
+
keyword = input_dict["keyword"]
|
239 |
+
output = self.decoder(
|
240 |
+
{
|
241 |
+
"word": cap[:, :-1],
|
242 |
+
"attn_emb": input_dict["attn_emb"],
|
243 |
+
"attn_emb_len": input_dict["attn_emb_len"],
|
244 |
+
"keyword": keyword,
|
245 |
+
"cap_padding_mask": cap_padding_mask
|
246 |
+
}
|
247 |
+
)
|
248 |
+
return output
|
249 |
+
|
250 |
+
def prepare_decoder_input(self, input_dict, output):
|
251 |
+
decoder_input = super().prepare_decoder_input(input_dict, output)
|
252 |
+
decoder_input["keyword"] = input_dict["keyword"]
|
253 |
+
return decoder_input
|
254 |
+
|
255 |
+
def prepare_beamsearch_decoder_input(self, input_dict, output_i):
|
256 |
+
decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
|
257 |
+
t = input_dict["t"]
|
258 |
+
i = input_dict["sample_idx"]
|
259 |
+
beam_size = input_dict["beam_size"]
|
260 |
+
if t == 0:
|
261 |
+
output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
|
262 |
+
beam_size)
|
263 |
+
decoder_input["keyword"] = output_i["keyword"]
|
264 |
+
return decoder_input
|
265 |
+
|
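All of the models above feed the caption shifted by one position and mask out padding, as in `TransformerModel.seq_forward`. A standalone sketch of that preparation, with an illustrative `pad_idx` (the real index comes from the vocabulary):

```python
import torch

pad_idx = 0   # illustrative only

# Two captions, right-padded with pad_idx.
cap = torch.tensor([[4, 23, 57, 89, 5, 0],
                    [4, 11, 42, 5, 0, 0]])

word = cap[:, :-1]                            # decoder input: all but the last token
cap_padding_mask = (cap == pad_idx)[:, :-1]   # True where the input token is padding
print(word.shape, cap_padding_mask)
```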
audio_to_text/captioning/models/utils.py
ADDED
@@ -0,0 +1,132 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
|
8 |
+
|
9 |
+
|
10 |
+
def sort_pack_padded_sequence(input, lengths):
|
11 |
+
sorted_lengths, indices = torch.sort(lengths, descending=True)
|
12 |
+
tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
|
13 |
+
inv_ix = indices.clone()
|
14 |
+
inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
|
15 |
+
return tmp, inv_ix
|
16 |
+
|
17 |
+
def pad_unsort_packed_sequence(input, inv_ix):
|
18 |
+
tmp, _ = pad_packed_sequence(input, batch_first=True)
|
19 |
+
tmp = tmp[inv_ix]
|
20 |
+
return tmp
|
21 |
+
|
22 |
+
def pack_wrapper(module, attn_feats, attn_feat_lens):
|
23 |
+
packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
|
24 |
+
if isinstance(module, torch.nn.RNNBase):
|
25 |
+
return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
|
26 |
+
else:
|
27 |
+
return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
|
28 |
+
|
29 |
+
def generate_length_mask(lens, max_length=None):
|
30 |
+
lens = torch.as_tensor(lens)
|
31 |
+
N = lens.size(0)
|
32 |
+
if max_length is None:
|
33 |
+
max_length = max(lens)
|
34 |
+
idxs = torch.arange(max_length).repeat(N).view(N, max_length)
|
35 |
+
idxs = idxs.to(lens.device)
|
36 |
+
mask = (idxs < lens.view(-1, 1))
|
37 |
+
return mask
|
38 |
+
|
39 |
+
def mean_with_lens(features, lens):
|
40 |
+
"""
|
41 |
+
features: [N, T, ...] (assume the second dimension represents length)
|
42 |
+
lens: [N,]
|
43 |
+
"""
|
44 |
+
lens = torch.as_tensor(lens)
|
45 |
+
if max(lens) != features.size(1):
|
46 |
+
max_length = features.size(1)
|
47 |
+
mask = generate_length_mask(lens, max_length)
|
48 |
+
else:
|
49 |
+
mask = generate_length_mask(lens)
|
50 |
+
mask = mask.to(features.device) # [N, T]
|
51 |
+
|
52 |
+
while mask.ndim < features.ndim:
|
53 |
+
mask = mask.unsqueeze(-1)
|
54 |
+
feature_mean = features * mask
|
55 |
+
feature_mean = feature_mean.sum(1)
|
56 |
+
while lens.ndim < feature_mean.ndim:
|
57 |
+
lens = lens.unsqueeze(1)
|
58 |
+
feature_mean = feature_mean / lens.to(features.device)
|
59 |
+
# feature_mean = features * mask.unsqueeze(-1)
|
60 |
+
# feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
|
61 |
+
return feature_mean
|
62 |
+
|
63 |
+
def max_with_lens(features, lens):
|
64 |
+
"""
|
65 |
+
features: [N, T, ...] (assume the second dimension represents length)
|
66 |
+
lens: [N,]
|
67 |
+
"""
|
68 |
+
lens = torch.as_tensor(lens)
|
69 |
+
mask = generate_length_mask(lens).to(features.device) # [N, T]
|
70 |
+
|
71 |
+
feature_max = features.clone()
|
72 |
+
feature_max[~mask] = float("-inf")
|
73 |
+
feature_max, _ = feature_max.max(1)
|
74 |
+
return feature_max
|
75 |
+
|
76 |
+
def repeat_tensor(x, n):
|
77 |
+
return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
|
78 |
+
|
79 |
+
def init(m, method="kaiming"):
|
80 |
+
if isinstance(m, (nn.Conv2d, nn.Conv1d)):
|
81 |
+
if method == "kaiming":
|
82 |
+
nn.init.kaiming_uniform_(m.weight)
|
83 |
+
elif method == "xavier":
|
84 |
+
nn.init.xavier_uniform_(m.weight)
|
85 |
+
else:
|
86 |
+
raise Exception(f"initialization method {method} not supported")
|
87 |
+
if m.bias is not None:
|
88 |
+
nn.init.constant_(m.bias, 0)
|
89 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
90 |
+
nn.init.constant_(m.weight, 1)
|
91 |
+
if m.bias is not None:
|
92 |
+
nn.init.constant_(m.bias, 0)
|
93 |
+
elif isinstance(m, nn.Linear):
|
94 |
+
if method == "kaiming":
|
95 |
+
nn.init.kaiming_uniform_(m.weight)
|
96 |
+
elif method == "xavier":
|
97 |
+
nn.init.xavier_uniform_(m.weight)
|
98 |
+
else:
|
99 |
+
raise Exception(f"initialization method {method} not supported")
|
100 |
+
if m.bias is not None:
|
101 |
+
nn.init.constant_(m.bias, 0)
|
102 |
+
elif isinstance(m, nn.Embedding):
|
103 |
+
if method == "kaiming":
|
104 |
+
nn.init.kaiming_uniform_(m.weight)
|
105 |
+
elif method == "xavier":
|
106 |
+
nn.init.xavier_uniform_(m.weight)
|
107 |
+
else:
|
108 |
+
raise Exception(f"initialization method {method} not supported")
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
class PositionalEncoding(nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, d_model, dropout=0.1, max_len=100):
|
116 |
+
super(PositionalEncoding, self).__init__()
|
117 |
+
self.dropout = nn.Dropout(p=dropout)
|
118 |
+
|
119 |
+
pe = torch.zeros(max_len, d_model)
|
120 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
121 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
|
122 |
+
(-math.log(10000.0) / d_model))
|
123 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
124 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
125 |
+
pe = pe.unsqueeze(0).transpose(0, 1)
|
126 |
+
# self.register_buffer("pe", pe)
|
127 |
+
self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
# x: [T, N, E]
|
131 |
+
x = x + self.pe[:x.size(0), :]
|
132 |
+
return self.dropout(x)
|
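A quick standalone illustration of the masking and pooling helpers defined above, assuming the repository root is importable; the feature tensor is arbitrary:

```python
import torch
from audio_to_text.captioning.models.utils import (
    generate_length_mask, mean_with_lens, max_with_lens, repeat_tensor)

feats = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)  # [N, T, D]
lens = torch.tensor([4, 2])   # the second sequence has only 2 valid steps

print(generate_length_mask(lens, 4))
# tensor([[ True,  True,  True,  True],
#         [ True,  True, False, False]])
print(mean_with_lens(feats, lens).shape)   # torch.Size([2, 3]), mean over valid steps only
print(max_with_lens(feats, lens).shape)    # torch.Size([2, 3])
print(repeat_tensor(feats[0], 3).shape)    # torch.Size([3, 4, 3]), as used in beam search
```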
audio_to_text/captioning/utils/README.md
ADDED
@@ -0,0 +1,19 @@
1 |
+
# Utils
|
2 |
+
|
3 |
+
Scripts in this directory are used as utility functions.
|
4 |
+
|
5 |
+
## BERT Pretrained Embeddings
|
6 |
+
|
7 |
+
You can load pretrained word embeddings from Google [BERT](https://github.com/google-research/bert#pre-trained-models) instead of training word embeddings from scratch. The scripts in `utils/bert` need a BERT server running in the background. We use the BERT server from [bert-as-service](https://github.com/hanxiao/bert-as-service).
|
8 |
+
|
9 |
+
To use bert-as-service, you first need to install the repository. It is recommended that you create a new environment with Tensorflow 1.3 to run the BERT server, since it is incompatible with Tensorflow 2.x.
|
10 |
+
|
11 |
+
After installing [bert-as-service](https://github.com/hanxiao/bert-as-service), download and run the BERT server by executing:
|
12 |
+
|
13 |
+
```bash
|
14 |
+
bash scripts/prepare_bert_server.sh <path-to-server> <num-workers> zh
|
15 |
+
```
|
16 |
+
|
17 |
+
By default, a server based on the BERT base Chinese model runs in the background. You can switch to other models by changing the corresponding model name and path in `scripts/prepare_bert_server.sh`.
|
18 |
+
|
19 |
+
To extract BERT word embeddings, you need to execute `utils/bert/create_word_embedding.py`.
|
audio_to_text/captioning/utils/__init__.py
ADDED
File without changes
|
audio_to_text/captioning/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (174 Bytes). View file
|
|
audio_to_text/captioning/utils/__pycache__/train_util.cpython-38.pyc
ADDED
Binary file (5.75 kB). View file
|
|
audio_to_text/captioning/utils/bert/create_sent_embedding.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
import pickle
|
2 |
+
import fire
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class EmbeddingExtractor(object):
|
9 |
+
|
10 |
+
def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
|
11 |
+
from sentence_transformers import SentenceTransformer
|
12 |
+
lang2model = {
|
13 |
+
"zh": "distiluse-base-multilingual-cased",
|
14 |
+
"en": "bert-base-nli-mean-tokens"
|
15 |
+
}
|
16 |
+
lang = "zh" if zh else "en"
|
17 |
+
model = SentenceTransformer(lang2model[lang])
|
18 |
+
|
19 |
+
self.extract(caption_file, model, output, dev)
|
20 |
+
|
21 |
+
def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
|
22 |
+
from bert_serving.client import BertClient
|
23 |
+
client = BertClient(ip)
|
24 |
+
|
25 |
+
self.extract(caption_file, client, output, dev)
|
26 |
+
|
27 |
+
def extract(self, caption_file: str, model, output, dev: bool):
|
28 |
+
caption_df = pd.read_json(caption_file, dtype={"key": str})
|
29 |
+
embeddings = {}
|
30 |
+
|
31 |
+
if dev:
|
32 |
+
with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
|
33 |
+
for idx, row in caption_df.iterrows():
|
34 |
+
caption = row["caption"]
|
35 |
+
key = row["key"]
|
36 |
+
cap_idx = row["caption_index"]
|
37 |
+
embedding = model.encode([caption])
|
38 |
+
embedding = np.array(embedding).reshape(-1)
|
39 |
+
embeddings[f"{key}_{cap_idx}"] = embedding
|
40 |
+
pbar.update()
|
41 |
+
|
42 |
+
else:
|
43 |
+
dump = {}
|
44 |
+
|
45 |
+
with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
|
46 |
+
for idx, row in caption_df.iterrows():
|
47 |
+
key = row["key"]
|
48 |
+
caption = row["caption"]
|
49 |
+
value = np.array(model.encode([caption])).reshape(-1)
|
50 |
+
|
51 |
+
if key not in embeddings.keys():
|
52 |
+
embeddings[key] = [value]
|
53 |
+
else:
|
54 |
+
embeddings[key].append(value)
|
55 |
+
|
56 |
+
pbar.update()
|
57 |
+
|
58 |
+
for key in embeddings:
|
59 |
+
dump[key] = np.stack(embeddings[key])
|
60 |
+
|
61 |
+
embeddings = dump
|
62 |
+
|
63 |
+
with open(output, "wb") as f:
|
64 |
+
pickle.dump(embeddings, f)
|
65 |
+
|
66 |
+
def extract_sbert(self,
|
67 |
+
input_json: str,
|
68 |
+
output: str):
|
69 |
+
from sentence_transformers import SentenceTransformer
|
70 |
+
import json
|
71 |
+
import torch
|
72 |
+
from h5py import File
|
73 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
74 |
+
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
|
75 |
+
model = model.to(device)
|
76 |
+
model.eval()
|
77 |
+
|
78 |
+
data = json.load(open(input_json))["audios"]
|
79 |
+
with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
|
80 |
+
for sample in data:
|
81 |
+
audio_id = sample["audio_id"]
|
82 |
+
for cap in sample["captions"]:
|
83 |
+
cap_id = cap["cap_id"]
|
84 |
+
store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
|
85 |
+
pbar.update()
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == "__main__":
|
89 |
+
fire.Fire(EmbeddingExtractor)
|
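`extract_sbert` writes one sentence embedding per `{audio_id}_{cap_id}` key into an HDF5 file. A hedged sketch of reading that file back (the path is a placeholder, and the embedding size depends on the sentence-transformer model used above):

```python
import h5py

# Placeholder path; keys follow the "{audio_id}_{cap_id}" convention above.
with h5py.File("sent_emb.h5", "r") as store:
    for key in list(store.keys())[:3]:
        print(key, store[key][()].shape)   # e.g. (384,) for paraphrase-MiniLM-L6-v2
```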
audio_to_text/captioning/utils/bert/create_word_embedding.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
|
6 |
+
from bert_serving.client import BertClient
|
7 |
+
import numpy as np
|
8 |
+
from tqdm import tqdm
|
9 |
+
import fire
|
10 |
+
import torch
|
11 |
+
|
12 |
+
sys.path.append(os.getcwd())
|
13 |
+
from utils.build_vocab import Vocabulary
|
14 |
+
|
15 |
+
def main(vocab_file: str, output: str, server_hostname: str):
|
16 |
+
client = BertClient(ip=server_hostname)
|
17 |
+
vocabulary = torch.load(vocab_file)
|
18 |
+
vocab_size = len(vocabulary)
|
19 |
+
|
20 |
+
fake_embedding = client.encode(["test"]).reshape(-1)
|
21 |
+
embed_size = fake_embedding.shape[0]
|
22 |
+
|
23 |
+
print("Encoding words into embeddings with size: ", embed_size)
|
24 |
+
|
25 |
+
embeddings = np.empty((vocab_size, embed_size))
|
26 |
+
for i in tqdm(range(len(embeddings)), ascii=True):
|
27 |
+
embeddings[i] = client.encode([vocabulary.idx2word[i]])
|
28 |
+
np.save(output, embeddings)
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == '__main__':
|
32 |
+
fire.Fire(main)
|
33 |
+
|
34 |
+
|
audio_to_text/captioning/utils/build_vocab.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
import json
|
2 |
+
from tqdm import tqdm
|
3 |
+
import logging
|
4 |
+
import pickle
|
5 |
+
from collections import Counter
|
6 |
+
import re
|
7 |
+
import fire
|
8 |
+
|
9 |
+
|
10 |
+
class Vocabulary(object):
|
11 |
+
"""Simple vocabulary wrapper."""
|
12 |
+
def __init__(self):
|
13 |
+
self.word2idx = {}
|
14 |
+
self.idx2word = {}
|
15 |
+
self.idx = 0
|
16 |
+
|
17 |
+
def add_word(self, word):
|
18 |
+
if not word in self.word2idx:
|
19 |
+
self.word2idx[word] = self.idx
|
20 |
+
self.idx2word[self.idx] = word
|
21 |
+
self.idx += 1
|
22 |
+
|
23 |
+
def __call__(self, word):
|
24 |
+
if not word in self.word2idx:
|
25 |
+
return self.word2idx["<unk>"]
|
26 |
+
return self.word2idx[word]
|
27 |
+
|
28 |
+
def __getitem__(self, word_id):
|
29 |
+
return self.idx2word[word_id]
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.word2idx)
|
33 |
+
|
34 |
+
|
35 |
+
def build_vocab(input_json: str,
|
36 |
+
threshold: int,
|
37 |
+
keep_punctuation: bool,
|
38 |
+
host_address: str,
|
39 |
+
character_level: bool = False,
|
40 |
+
zh: bool = True ):
|
41 |
+
"""Build vocabulary from csv file with a given threshold to drop all counts < threshold
|
42 |
+
|
43 |
+
Args:
|
44 |
+
input_json(string): Preprossessed json file. Structure like this:
|
45 |
+
{
|
46 |
+
'audios': [
|
47 |
+
{
|
48 |
+
'audio_id': 'xxx',
|
49 |
+
'captions': [
|
50 |
+
{
|
51 |
+
'caption': 'xxx',
|
52 |
+
'cap_id': 'xxx'
|
53 |
+
}
|
54 |
+
]
|
55 |
+
},
|
56 |
+
...
|
57 |
+
]
|
58 |
+
}
|
59 |
+
threshold (int): Threshold to drop all words with counts < threshold
|
60 |
+
keep_punctuation (bool): Includes or excludes punctuation.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
vocab (Vocab): Object with the processed vocabulary
|
64 |
+
"""
|
65 |
+
data = json.load(open(input_json, "r"))["audios"]
|
66 |
+
counter = Counter()
|
67 |
+
pretokenized = "tokens" in data[0]["captions"][0]
|
68 |
+
|
69 |
+
if zh:
|
70 |
+
from nltk.parse.corenlp import CoreNLPParser
|
71 |
+
from zhon.hanzi import punctuation
|
72 |
+
if not pretokenized:
|
73 |
+
parser = CoreNLPParser(host_address)
|
74 |
+
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
75 |
+
for cap_idx in range(len(data[audio_idx]["captions"])):
|
76 |
+
if pretokenized:
|
77 |
+
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
78 |
+
else:
|
79 |
+
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
80 |
+
# Remove all punctuations
|
81 |
+
if not keep_punctuation:
|
82 |
+
caption = re.sub("[{}]".format(punctuation), "", caption)
|
83 |
+
if character_level:
|
84 |
+
tokens = list(caption)
|
85 |
+
else:
|
86 |
+
tokens = list(parser.tokenize(caption))
|
87 |
+
data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
|
88 |
+
counter.update(tokens)
|
89 |
+
else:
|
90 |
+
if pretokenized:
|
91 |
+
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
92 |
+
for cap_idx in range(len(data[audio_idx]["captions"])):
|
93 |
+
tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
|
94 |
+
counter.update(tokens)
|
95 |
+
else:
|
96 |
+
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
|
97 |
+
captions = {}
|
98 |
+
for audio_idx in range(len(data)):
|
99 |
+
audio_id = data[audio_idx]["audio_id"]
|
100 |
+
captions[audio_id] = []
|
101 |
+
for cap_idx in range(len(data[audio_idx]["captions"])):
|
102 |
+
caption = data[audio_idx]["captions"][cap_idx]["caption"]
|
103 |
+
captions[audio_id].append({
|
104 |
+
"audio_id": audio_id,
|
105 |
+
"id": cap_idx,
|
106 |
+
"caption": caption
|
107 |
+
})
|
108 |
+
tokenizer = PTBTokenizer()
|
109 |
+
captions = tokenizer.tokenize(captions)
|
110 |
+
for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
|
111 |
+
audio_id = data[audio_idx]["audio_id"]
|
112 |
+
for cap_idx in range(len(data[audio_idx]["captions"])):
|
113 |
+
tokens = captions[audio_id][cap_idx]
|
114 |
+
data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
|
115 |
+
counter.update(tokens.split(" "))
|
116 |
+
|
117 |
+
if not pretokenized:
|
118 |
+
json.dump({ "audios": data }, open(input_json, "w"), indent=4, ensure_ascii=not zh)
|
119 |
+
words = [word for word, cnt in counter.items() if cnt >= threshold]
|
120 |
+
|
121 |
+
# Create a vocab wrapper and add some special tokens.
|
122 |
+
vocab = Vocabulary()
|
123 |
+
vocab.add_word("<pad>")
|
124 |
+
vocab.add_word("<start>")
|
125 |
+
vocab.add_word("<end>")
|
126 |
+
vocab.add_word("<unk>")
|
127 |
+
|
128 |
+
# Add the words to the vocabulary.
|
129 |
+
for word in words:
|
130 |
+
vocab.add_word(word)
|
131 |
+
return vocab
|
132 |
+
|
133 |
+
|
134 |
+
def process(input_json: str,
|
135 |
+
output_file: str,
|
136 |
+
threshold: int = 1,
|
137 |
+
keep_punctuation: bool = False,
|
138 |
+
character_level: bool = False,
|
139 |
+
host_address: str = "http://localhost:9000",
|
140 |
+
zh: bool = False):
|
141 |
+
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
142 |
+
logging.basicConfig(level=logging.INFO, format=logfmt)
|
143 |
+
logging.info("Build Vocab")
|
144 |
+
vocabulary = build_vocab(
|
145 |
+
input_json=input_json, threshold=threshold, keep_punctuation=keep_punctuation,
|
146 |
+
host_address=host_address, character_level=character_level, zh=zh)
|
147 |
+
pickle.dump(vocabulary, open(output_file, "wb"))
|
148 |
+
logging.info("Total vocabulary size: {}".format(len(vocabulary)))
|
149 |
+
logging.info("Saved vocab to '{}'".format(output_file))
|
150 |
+
|
151 |
+
|
152 |
+
if __name__ == '__main__':
|
153 |
+
fire.Fire(process)
|
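Usage note (not part of the upload): a minimal sketch of how the vocabulary builder above is invoked; the file paths are placeholders, and it assumes the script is run from the `audio_to_text/captioning` directory so that `utils` is importable.
# Illustrative only: hypothetical paths.
from utils.build_vocab import process

process(
    input_json="data/audiocaps/train.json",   # caption json in the structure shown in the docstring
    output_file="data/audiocaps/vocab.pkl",   # pickled Vocabulary object
    threshold=1,                              # keep every word that appears at least once
    zh=False,                                 # English captions -> PTBTokenizer branch
)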
audio_to_text/captioning/utils/build_vocab_ltp.py
ADDED
@@ -0,0 +1,150 @@
import json
from tqdm import tqdm
import logging
import pickle
from collections import Counter
import re
import fire

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx["<unk>"]
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

def build_vocab(input_json: str,
                output_json: str,
                threshold: int,
                keep_punctuation: bool,
                character_level: bool = False,
                zh: bool = True):
    """Build vocabulary from the caption json file with a given threshold to drop all counts < threshold

    Args:
        input_json(string): Preprocessed json file. Structure like this:
            {
              'audios': [
                {
                  'audio_id': 'xxx',
                  'captions': [
                    {
                      'caption': 'xxx',
                      'cap_id': 'xxx'
                    }
                  ]
                },
                ...
              ]
            }
        threshold (int): Threshold to drop all words with counts < threshold
        keep_punctuation (bool): Includes or excludes punctuation.

    Returns:
        vocab (Vocab): Object with the processed vocabulary
    """
    data = json.load(open(input_json, "r"))["audios"]
    counter = Counter()
    pretokenized = "tokens" in data[0]["captions"][0]

    if zh:
        from ltp import LTP
        from zhon.hanzi import punctuation
        if not pretokenized:
            parser = LTP("base")
        for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
            for cap_idx in range(len(data[audio_idx]["captions"])):
                if pretokenized:
                    tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
                else:
                    caption = data[audio_idx]["captions"][cap_idx]["caption"]
                    if character_level:
                        tokens = list(caption)
                    else:
                        tokens, _ = parser.seg([caption])
                        tokens = tokens[0]
                    # Remove all punctuations
                    if not keep_punctuation:
                        tokens = [token for token in tokens if token not in punctuation]
                    data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
                counter.update(tokens)
    else:
        if pretokenized:
            for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
                for cap_idx in range(len(data[audio_idx]["captions"])):
                    tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
                    counter.update(tokens)
        else:
            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
            captions = {}
            for audio_idx in range(len(data)):
                audio_id = data[audio_idx]["audio_id"]
                captions[audio_id] = []
                for cap_idx in range(len(data[audio_idx]["captions"])):
                    caption = data[audio_idx]["captions"][cap_idx]["caption"]
                    captions[audio_id].append({
                        "audio_id": audio_id,
                        "id": cap_idx,
                        "caption": caption
                    })
            tokenizer = PTBTokenizer()
            captions = tokenizer.tokenize(captions)
            for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
                audio_id = data[audio_idx]["audio_id"]
                for cap_idx in range(len(data[audio_idx]["captions"])):
                    tokens = captions[audio_id][cap_idx]
                    data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
                    counter.update(tokens.split(" "))

    if not pretokenized:
        if output_json is None:
            output_json = input_json
        json.dump({ "audios": data }, open(output_json, "w"), indent=4, ensure_ascii=not zh)
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word("<pad>")
    vocab.add_word("<start>")
    vocab.add_word("<end>")
    vocab.add_word("<unk>")

    # Add the words to the vocabulary.
    for word in words:
        vocab.add_word(word)
    return vocab

def process(input_json: str,
            output_file: str,
            output_json: str = None,
            threshold: int = 1,
            keep_punctuation: bool = False,
            character_level: bool = False,
            zh: bool = True):
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info("Build Vocab")
    vocabulary = build_vocab(
        input_json=input_json, output_json=output_json, threshold=threshold,
        keep_punctuation=keep_punctuation, character_level=character_level, zh=zh)
    pickle.dump(vocabulary, open(output_file, "wb"))
    logging.info("Total vocabulary size: {}".format(len(vocabulary)))
    logging.info("Saved vocab to '{}'".format(output_file))


if __name__ == '__main__':
    fire.Fire(process)
audio_to_text/captioning/utils/build_vocab_spacy.py
ADDED
@@ -0,0 +1,152 @@
import json
from tqdm import tqdm
import logging
import pickle
from collections import Counter
import re
import fire

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx["<unk>"]
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)


def build_vocab(input_json: str,
                output_json: str,
                threshold: int,
                keep_punctuation: bool,
                host_address: str,
                character_level: bool = False,
                retokenize: bool = True,
                zh: bool = True):
    """Build vocabulary from the caption json file with a given threshold to drop all counts < threshold

    Args:
        input_json(string): Preprocessed json file. Structure like this:
            {
              'audios': [
                {
                  'audio_id': 'xxx',
                  'captions': [
                    {
                      'caption': 'xxx',
                      'cap_id': 'xxx'
                    }
                  ]
                },
                ...
              ]
            }
        threshold (int): Threshold to drop all words with counts < threshold
        keep_punctuation (bool): Includes or excludes punctuation.

    Returns:
        vocab (Vocab): Object with the processed vocabulary
    """
    data = json.load(open(input_json, "r"))["audios"]
    counter = Counter()
    if retokenize:
        pretokenized = False
    else:
        pretokenized = "tokens" in data[0]["captions"][0]

    if zh:
        from nltk.parse.corenlp import CoreNLPParser
        from zhon.hanzi import punctuation
        if not pretokenized:
            parser = CoreNLPParser(host_address)
        for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
            for cap_idx in range(len(data[audio_idx]["captions"])):
                if pretokenized:
                    tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
                else:
                    caption = data[audio_idx]["captions"][cap_idx]["caption"]
                    # Remove all punctuations
                    if not keep_punctuation:
                        caption = re.sub("[{}]".format(punctuation), "", caption)
                    if character_level:
                        tokens = list(caption)
                    else:
                        tokens = list(parser.tokenize(caption))
                    data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
                counter.update(tokens)
    else:
        if pretokenized:
            for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
                for cap_idx in range(len(data[audio_idx]["captions"])):
                    tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
                    counter.update(tokens)
        else:
            import spacy
            tokenizer = spacy.load("en_core_web_sm", disable=["parser", "ner"])
            for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
                captions = data[audio_idx]["captions"]
                for cap_idx in range(len(captions)):
                    caption = captions[cap_idx]["caption"]
                    doc = tokenizer(caption)
                    tokens = " ".join([str(token).lower() for token in doc])
                    data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
                    counter.update(tokens.split(" "))

    if not pretokenized:
        if output_json is None:
            json.dump({ "audios": data }, open(input_json, "w"),
                      indent=4, ensure_ascii=not zh)
        else:
            json.dump({ "audios": data }, open(output_json, "w"),
                      indent=4, ensure_ascii=not zh)

    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word("<pad>")
    vocab.add_word("<start>")
    vocab.add_word("<end>")
    vocab.add_word("<unk>")

    # Add the words to the vocabulary.
    for word in words:
        vocab.add_word(word)
    return vocab

def process(input_json: str,
            output_file: str,
            output_json: str = None,
            threshold: int = 1,
            keep_punctuation: bool = False,
            character_level: bool = False,
            retokenize: bool = False,
            host_address: str = "http://localhost:9000",
            zh: bool = True):
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info("Build Vocab")
    vocabulary = build_vocab(
        input_json=input_json, output_json=output_json, threshold=threshold,
        keep_punctuation=keep_punctuation, host_address=host_address,
        character_level=character_level, retokenize=retokenize, zh=zh)
    pickle.dump(vocabulary, open(output_file, "wb"))
    logging.info("Total vocabulary size: {}".format(len(vocabulary)))
    logging.info("Saved vocab to '{}'".format(output_file))


if __name__ == '__main__':
    fire.Fire(process)
audio_to_text/captioning/utils/eval_round_robin.py
ADDED
@@ -0,0 +1,182 @@
import copy
import json

import numpy as np
import fire


def evaluate_annotation(key2refs, scorer):
    if scorer.method() == "Bleu":
        scores = np.array([ 0.0 for n in range(4) ])
    else:
        scores = 0
    num_cap_per_audio = len(next(iter(key2refs.values())))

    for i in range(num_cap_per_audio):
        if i > 0:
            for key in key2refs:
                key2refs[key].insert(0, res[key][0])
        res = { key: [refs.pop(),] for key, refs in key2refs.items() }
        score, _ = scorer.compute_score(key2refs, res)

        if scorer.method() == "Bleu":
            scores += np.array(score)
        else:
            scores += score

    score = scores / num_cap_per_audio
    return score

def evaluate_prediction(key2pred, key2refs, scorer):
    if scorer.method() == "Bleu":
        scores = np.array([ 0.0 for n in range(4) ])
    else:
        scores = 0
    num_cap_per_audio = len(next(iter(key2refs.values())))

    for i in range(num_cap_per_audio):
        key2refs_i = {}
        for key, refs in key2refs.items():
            key2refs_i[key] = refs[:i] + refs[i+1:]
        score, _ = scorer.compute_score(key2refs_i, key2pred)

        if scorer.method() == "Bleu":
            scores += np.array(score)
        else:
            scores += score

    score = scores / num_cap_per_audio
    return score


class Evaluator(object):

    def eval_annotation(self, annotation, output):
        captions = json.load(open(annotation, "r"))["audios"]

        key2refs = {}
        for audio_idx in range(len(captions)):
            audio_id = captions[audio_idx]["audio_id"]
            key2refs[audio_id] = []
            for caption in captions[audio_idx]["captions"]:
                key2refs[audio_id].append(caption["caption"])

        from fense.fense import Fense
        scores = {}
        scorer = Fense()
        scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)

        refs4eval = {}
        for key, refs in key2refs.items():
            refs4eval[key] = []
            for idx, ref in enumerate(refs):
                refs4eval[key].append({
                    "audio_id": key,
                    "id": idx,
                    "caption": ref
                })

        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

        tokenizer = PTBTokenizer()
        key2refs = tokenizer.tokenize(refs4eval)


        from pycocoevalcap.bleu.bleu import Bleu
        from pycocoevalcap.cider.cider import Cider
        from pycocoevalcap.rouge.rouge import Rouge
        from pycocoevalcap.meteor.meteor import Meteor
        from pycocoevalcap.spice.spice import Spice


        scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
        for scorer in scorers:
            scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)

        spider = 0
        with open(output, "w") as f:
            for name, score in scores.items():
                if name == "Bleu":
                    for n in range(4):
                        f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
                else:
                    f.write("{}: {:6.3f}\n".format(name, score))
                if name in ["CIDEr", "SPICE"]:
                    spider += score
            f.write("SPIDEr: {:6.3f}\n".format(spider / 2))

    def eval_prediction(self, prediction, annotation, output):
        ref_captions = json.load(open(annotation, "r"))["audios"]

        key2refs = {}
        for audio_idx in range(len(ref_captions)):
            audio_id = ref_captions[audio_idx]["audio_id"]
            key2refs[audio_id] = []
            for caption in ref_captions[audio_idx]["captions"]:
                key2refs[audio_id].append(caption["caption"])

        pred_captions = json.load(open(prediction, "r"))["predictions"]

        key2pred = {}
        for audio_idx in range(len(pred_captions)):
            item = pred_captions[audio_idx]
            audio_id = item["filename"]
            key2pred[audio_id] = [item["tokens"]]

        from fense.fense import Fense
        scores = {}
        scorer = Fense()
        scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)

        refs4eval = {}
        for key, refs in key2refs.items():
            refs4eval[key] = []
            for idx, ref in enumerate(refs):
                refs4eval[key].append({
                    "audio_id": key,
                    "id": idx,
                    "caption": ref
                })

        preds4eval = {}
        for key, preds in key2pred.items():
            preds4eval[key] = []
            for idx, pred in enumerate(preds):
                preds4eval[key].append({
                    "audio_id": key,
                    "id": idx,
                    "caption": pred
                })

        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

        tokenizer = PTBTokenizer()
        key2refs = tokenizer.tokenize(refs4eval)
        key2pred = tokenizer.tokenize(preds4eval)


        from pycocoevalcap.bleu.bleu import Bleu
        from pycocoevalcap.cider.cider import Cider
        from pycocoevalcap.rouge.rouge import Rouge
        from pycocoevalcap.meteor.meteor import Meteor
        from pycocoevalcap.spice.spice import Spice

        scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
        for scorer in scorers:
            scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)

        spider = 0
        with open(output, "w") as f:
            for name, score in scores.items():
                if name == "Bleu":
                    for n in range(4):
                        f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
                else:
                    f.write("{}: {:6.3f}\n".format(name, score))
                if name in ["CIDEr", "SPICE"]:
                    spider += score
            f.write("SPIDEr: {:6.3f}\n".format(spider / 2))


if __name__ == "__main__":
    fire.Fire(Evaluator)
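Note (not part of the upload): the two helpers above score in a leave-one-out round robin, holding out one caption per audio at a time and averaging. A toy sketch of the prediction case with a stand-in scorer; `DummyScorer` is hypothetical and only mimics the `method()` / `compute_score(refs, preds)` interface of the pycocoevalcap scorers.
# Toy illustration of evaluate_prediction's leave-one-out averaging (hypothetical scorer).
class DummyScorer:
    def method(self):
        return "Dummy"
    def compute_score(self, key2refs, key2pred):
        # fraction of predictions that exactly match one of the remaining references
        hits = sum(key2pred[k][0] in refs for k, refs in key2refs.items())
        return hits / len(key2refs), None

key2refs = {"a1": ["a dog barks", "a dog is barking", "barking dog"]}
key2pred = {"a1": ["a dog barks"]}
print(evaluate_prediction(key2pred, key2refs, DummyScorer()))  # averaged over 3 held-out splits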
audio_to_text/captioning/utils/fasttext/create_word_embedding.py
ADDED
@@ -0,0 +1,50 @@
# coding=utf-8
#!/usr/bin/env python3

import numpy as np
import pandas as pd
import torch
from gensim.models import FastText
from tqdm import tqdm
import fire

import sys
import os
sys.path.append(os.getcwd())
from utils.build_vocab import Vocabulary

def create_embedding(caption_file: str,
                     vocab_file: str,
                     embed_size: int,
                     output: str,
                     **fasttext_kwargs):
    caption_df = pd.read_json(caption_file)
    caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])

    sentences = list(caption_df["tokens"].values)
    vocabulary = torch.load(vocab_file, map_location="cpu")

    epochs = fasttext_kwargs.get("epochs", 10)
    model = FastText(size=embed_size, min_count=1, **fasttext_kwargs)
    model.build_vocab(sentences=sentences)
    model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)

    word_embeddings = np.zeros((len(vocabulary), embed_size))

    with tqdm(total=len(vocabulary), ascii=True) as pbar:
        for word, idx in vocabulary.word2idx.items():
            if word == "<pad>" or word == "<unk>":
                continue
            word_embeddings[idx] = model.wv[word]
            pbar.update()

    np.save(output, word_embeddings)

    print("Finish writing fasttext embeddings to " + output)


if __name__ == "__main__":
    fire.Fire(create_embedding)
audio_to_text/captioning/utils/lr_scheduler.py
ADDED
@@ -0,0 +1,128 @@
import math
import torch


class ExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, total_iters, final_lrs,
                 warmup_iters=3000, last_epoch=-1, verbose=False):
        self.total_iters = total_iters
        self.final_lrs = final_lrs
        if not isinstance(self.final_lrs, list) and not isinstance(
                self.final_lrs, tuple):
            self.final_lrs = [self.final_lrs] * len(optimizer.param_groups)
        self.warmup_iters = warmup_iters
        self.bases = [0.0,] * len(optimizer.param_groups)
        super().__init__(optimizer, last_epoch, verbose)
        for i, (base_lr, final_lr) in enumerate(zip(self.base_lrs, self.final_lrs)):
            base = (final_lr / base_lr) ** (1 / (
                self.total_iters - self.warmup_iters))
            self.bases[i] = base

    def _get_closed_form_lr(self):
        warmup_coeff = 1.0
        current_iter = self._step_count
        if current_iter < self.warmup_iters:
            warmup_coeff = current_iter / self.warmup_iters
        current_lrs = []
        # if not self.linear_warmup:
        #     for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, self.bases):
        #         # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
        #         current_lr = warmup_coeff * base_lr * (base ** (current_iter - self.warmup_iters))
        #         current_lrs.append(current_lr)
        # else:
        for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs,
                                           self.bases):
            if current_iter <= self.warmup_iters:
                current_lr = warmup_coeff * base_lr
            else:
                # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
                current_lr = base_lr * (base ** (current_iter - self.warmup_iters))
            current_lrs.append(current_lr)
        return current_lrs

    def get_lr(self):
        return self._get_closed_form_lr()


class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, model_size=512, factor=1, warmup_iters=3000,
                 last_epoch=-1, verbose=False):
        self.model_size = model_size
        self.warmup_iters = warmup_iters
        # self.factors = [group["lr"] / (self.model_size ** (-0.5) * self.warmup_iters ** (-0.5)) for group in optimizer.param_groups]
        self.factor = factor
        super().__init__(optimizer, last_epoch, verbose)

    def _get_closed_form_lr(self):
        current_iter = self._step_count
        current_lrs = []
        for _ in self.base_lrs:
            current_lr = self.factor * \
                (self.model_size ** (-0.5) * min(current_iter ** (-0.5),
                    current_iter * self.warmup_iters ** (-1.5)))
            current_lrs.append(current_lr)
        return current_lrs

    def get_lr(self):
        return self._get_closed_form_lr()


class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, total_iters, warmup_iters,
                 num_cycles=0.5, last_epoch=-1, verbose=False):
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.num_cycles = num_cycles
        super().__init__(optimizer, last_epoch, verbose)

    def lr_lambda(self, iteration):
        if iteration < self.warmup_iters:
            return float(iteration) / float(max(1, self.warmup_iters))
        progress = float(iteration - self.warmup_iters) / float(max(1,
            self.total_iters - self.warmup_iters))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(
            self.num_cycles) * 2.0 * progress)))

    def _get_closed_form_lr(self):
        current_iter = self._step_count
        current_lrs = []
        for base_lr in self.base_lrs:
            current_lr = base_lr * self.lr_lambda(current_iter)
            current_lrs.append(current_lr)
        return current_lrs

    def get_lr(self):
        return self._get_closed_form_lr()


if __name__ == "__main__":
    model = torch.nn.Linear(10, 5)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)
    epochs = 25
    iters = 600
    scheduler = CosineWithWarmup(optimizer, 600 * 25, 600 * 5,)
    # scheduler = ExponentialDecayScheduler(optimizer, 600 * 25, 5e-7, 600 * 5)
    criterion = torch.nn.MSELoss()
    lrs = []
    for epoch in range(1, epochs + 1):
        for iteration in range(1, iters + 1):
            optimizer.zero_grad()
            x = torch.randn(4, 10)
            y = torch.randn(4, 5)
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            scheduler.step()
            # print(f"lr: {scheduler.get_last_lr()}")
            # lrs.append(scheduler.get_last_lr())
            lrs.append(optimizer.param_groups[0]["lr"])
    import matplotlib.pyplot as plt
    plt.plot(list(range(1, len(lrs) + 1)), lrs, '-o', markersize=1)
    # plt.legend(loc="best")
    plt.xlabel("Iteration")
    plt.ylabel("LR")

    plt.savefig("lr_curve.png", dpi=100)
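Note (not part of the upload): a minimal sketch of how these schedulers are typically attached to an optimizer, assuming the classes above are importable; the model and iteration counts are placeholders. The Noam schedule ramps the learning rate linearly over `warmup_iters` steps, then decays it with the inverse square root of the step count, scaled by `factor` and `model_size` rather than the optimizer's base learning rate.
# Illustrative only: toy model and step counts.
import torch

model = torch.nn.Linear(64, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)  # base lr is ignored by NoamScheduler
scheduler = NoamScheduler(optimizer, model_size=256, warmup_iters=3000)

for step in range(10000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 64)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # call once per iteration, not once per epoch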
audio_to_text/captioning/utils/model_eval_diff.py
ADDED
@@ -0,0 +1,110 @@
import os
import sys
import copy
import pickle

import numpy as np
import pandas as pd
import fire

sys.path.append(os.getcwd())


def coco_score(refs, pred, scorer):
    if scorer.method() == "Bleu":
        scores = np.array([ 0.0 for n in range(4) ])
    else:
        scores = 0
    num_cap_per_audio = len(refs[list(refs.keys())[0]])

    for i in range(num_cap_per_audio):
        if i > 0:
            for key in refs:
                refs[key].insert(0, res[key][0])
        res = {key: [refs[key].pop(),] for key in refs}
        score, _ = scorer.compute_score(refs, pred)

        if scorer.method() == "Bleu":
            scores += np.array(score)
        else:
            scores += score

    score = scores / num_cap_per_audio

    for key in refs:
        refs[key].insert(0, res[key][0])
    score_allref, _ = scorer.compute_score(refs, pred)
    diff = score_allref - score
    return diff

def embedding_score(refs, pred, scorer):

    num_cap_per_audio = len(refs[list(refs.keys())[0]])
    scores = 0

    for i in range(num_cap_per_audio):
        res = {key: [refs[key][i],] for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
        refs_i = {key: np.concatenate([refs[key][:i], refs[key][i+1:]]) for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
        score, _ = scorer.compute_score(refs_i, pred)

        scores += score

    score = scores / num_cap_per_audio

    score_allref, _ = scorer.compute_score(refs, pred)
    diff = score_allref - score
    return diff

def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False):
    output_df = pd.read_json(output_file)
    output_df["key"] = output_df["filename"].apply(lambda x: os.path.splitext(os.path.basename(x))[0])
    pred = output_df.groupby("key")["tokens"].apply(list).to_dict()

    label_df = pd.read_json(eval_caption_file)
    if zh:
        refs = label_df.groupby("key")["tokens"].apply(list).to_dict()
    else:
        refs = label_df.groupby("key")["caption"].apply(list).to_dict()

    from pycocoevalcap.bleu.bleu import Bleu
    from pycocoevalcap.cider.cider import Cider
    from pycocoevalcap.rouge.rouge import Rouge

    scorer = Bleu(zh=zh)
    bleu_scores = coco_score(copy.deepcopy(refs), pred, scorer)
    scorer = Cider(zh=zh)
    cider_score = coco_score(copy.deepcopy(refs), pred, scorer)
    scorer = Rouge(zh=zh)
    rouge_score = coco_score(copy.deepcopy(refs), pred, scorer)

    if not zh:
        from pycocoevalcap.meteor.meteor import Meteor
        scorer = Meteor()
        meteor_score = coco_score(copy.deepcopy(refs), pred, scorer)

        from pycocoevalcap.spice.spice import Spice
        scorer = Spice()
        spice_score = coco_score(copy.deepcopy(refs), pred, scorer)

    # from audiocaptioneval.sentbert.sentencebert import SentenceBert
    # scorer = SentenceBert(zh=zh)
    # with open(eval_embedding_file, "rb") as f:
    #     ref_embeddings = pickle.load(f)

    # sent_bert = embedding_score(ref_embeddings, pred, scorer)

    with open(output, "w") as f:
        f.write("Diff:\n")
        for n in range(4):
            f.write("BLEU-{}: {:6.3f}\n".format(n+1, bleu_scores[n]))
        f.write("CIDEr: {:6.3f}\n".format(cider_score))
        f.write("ROUGE: {:6.3f}\n".format(rouge_score))
        if not zh:
            f.write("Meteor: {:6.3f}\n".format(meteor_score))
            f.write("SPICE: {:6.3f}\n".format(spice_score))
        # f.write("SentenceBert: {:6.3f}\n".format(sent_bert))


if __name__ == "__main__":
    fire.Fire(main)
audio_to_text/captioning/utils/predict_nn.py
ADDED
@@ -0,0 +1,49 @@
import json
import random
import argparse
import numpy as np
from tqdm import tqdm
from h5py import File
import sklearn.metrics

random.seed(1)

parser = argparse.ArgumentParser()
parser.add_argument("train_feature", type=str)
parser.add_argument("train_corpus", type=str)
parser.add_argument("pred_feature", type=str)
parser.add_argument("output_json", type=str)

args = parser.parse_args()
train_embs = []
train_idx_to_audioid = []
with File(args.train_feature, "r") as store:
    for audio_id, embedding in tqdm(store.items(), ascii=True):
        train_embs.append(embedding[()])
        train_idx_to_audioid.append(audio_id)

train_annotation = json.load(open(args.train_corpus, "r"))["audios"]
train_audioid_to_tokens = {}
for item in train_annotation:
    audio_id = item["audio_id"]
    train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]]
train_embs = np.stack(train_embs)


pred_data = []
pred_embs = []
pred_idx_to_audioids = []
with File(args.pred_feature, "r") as store:
    for audio_id, embedding in tqdm(store.items(), ascii=True):
        pred_embs.append(embedding[()])
        pred_idx_to_audioids.append(audio_id)
pred_embs = np.stack(pred_embs)

similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
for idx, audio_id in enumerate(pred_idx_to_audioids):
    train_idx = similarity[idx].argmax()
    pred_data.append({
        "filename": audio_id,
        "tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
    })
json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, indent=4)
audio_to_text/captioning/utils/remove_optimizer.py
ADDED
@@ -0,0 +1,18 @@
import argparse
import torch


def main(checkpoint):
    state_dict = torch.load(checkpoint, map_location="cpu")
    if "optimizer" in state_dict:
        del state_dict["optimizer"]
    if "lr_scheduler" in state_dict:
        del state_dict["lr_scheduler"]
    torch.save(state_dict, checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint", type=str)
    args = parser.parse_args()
    main(args.checkpoint)
audio_to_text/captioning/utils/report_results.py
ADDED
@@ -0,0 +1,37 @@
from pathlib import Path
import argparse
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--input", help="input filename", type=str, nargs="+")
parser.add_argument("--output", help="output result file", default=None)

args = parser.parse_args()


scores = {}
for path in args.input:
    with open(path, "r") as reader:
        for line in reader.readlines():
            metric, score = line.strip().split(": ")
            score = float(score)
            if metric not in scores:
                scores[metric] = []
            scores[metric].append(score)

if len(scores) == 0:
    print("No experiment directory found, wrong path?")
    exit(1)

with open(args.output, "w") as writer:
    print("Average results: ", file=writer)
    for metric, score in scores.items():
        score = np.array(score)
        mean = np.mean(score)
        std = np.std(score)
        print(f"{metric}: {mean:.3f} (±{std:.3f})", file=writer)
    print("", file=writer)
    print("Best results: ", file=writer)
    for metric, score in scores.items():
        score = np.max(score)
        print(f"{metric}: {score:.3f}", file=writer)
audio_to_text/captioning/utils/tokenize_caption.py
ADDED
@@ -0,0 +1,86 @@
import json
from tqdm import tqdm
import re
import fire


def tokenize_caption(input_json: str,
                     keep_punctuation: bool = False,
                     host_address: str = None,
                     character_level: bool = False,
                     zh: bool = True,
                     output_json: str = None):
    """Tokenize captions in a preprocessed caption json file and write the tokens back.

    Args:
        input_json(string): Preprocessed json file. Structure like this:
            {
              'audios': [
                {
                  'audio_id': 'xxx',
                  'captions': [
                    {
                      'caption': 'xxx',
                      'cap_id': 'xxx'
                    }
                  ]
                },
                ...
              ]
            }
        keep_punctuation (bool): Includes or excludes punctuation.
    """
    data = json.load(open(input_json, "r"))["audios"]

    if zh:
        from nltk.parse.corenlp import CoreNLPParser
        from zhon.hanzi import punctuation
        parser = CoreNLPParser(host_address)
        for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
            for cap_idx in range(len(data[audio_idx]["captions"])):
                caption = data[audio_idx]["captions"][cap_idx]["caption"]
                # Remove all punctuations
                if not keep_punctuation:
                    caption = re.sub("[{}]".format(punctuation), "", caption)
                if character_level:
                    tokens = list(caption)
                else:
                    tokens = list(parser.tokenize(caption))
                data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
    else:
        from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        captions = {}
        for audio_idx in range(len(data)):
            audio_id = data[audio_idx]["audio_id"]
            captions[audio_id] = []
            for cap_idx in range(len(data[audio_idx]["captions"])):
                caption = data[audio_idx]["captions"][cap_idx]["caption"]
                captions[audio_id].append({
                    "audio_id": audio_id,
                    "id": cap_idx,
                    "caption": caption
                })
        tokenizer = PTBTokenizer()
        captions = tokenizer.tokenize(captions)
        for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
            audio_id = data[audio_idx]["audio_id"]
            for cap_idx in range(len(data[audio_idx]["captions"])):
                tokens = captions[audio_id][cap_idx]
                data[audio_idx]["captions"][cap_idx]["tokens"] = tokens

    if output_json:
        json.dump(
            { "audios": data }, open(output_json, "w"),
            indent=4, ensure_ascii=not zh)
    else:
        json.dump(
            { "audios": data }, open(input_json, "w"),
            indent=4, ensure_ascii=not zh)


if __name__ == "__main__":
    fire.Fire(tokenize_caption)
audio_to_text/captioning/utils/train_util.py
ADDED
@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
#!/usr/bin/env python3
import os
import sys
import logging
from typing import Callable, Dict, Union
import yaml
import torch
from torch.optim.swa_utils import AveragedModel as torch_average_model
import numpy as np
import pandas as pd
from pprint import pformat


def load_dict_from_csv(csv, cols):
    df = pd.read_csv(csv, sep="\t")
    output = dict(zip(df[cols[0]], df[cols[1]]))
    return output


def init_logger(filename, level="INFO"):
    formatter = logging.Formatter(
        "[ %(levelname)s : %(asctime)s ] - %(message)s")
    logger = logging.getLogger(__name__ + "." + filename)
    logger.setLevel(getattr(logging, level))
    # Log results to std
    # stdhandler = logging.StreamHandler(sys.stdout)
    # stdhandler.setFormatter(formatter)
    # Dump log to file
    filehandler = logging.FileHandler(filename)
    filehandler.setFormatter(formatter)
    logger.addHandler(filehandler)
    # logger.addHandler(stdhandler)
    return logger


def init_obj(module, config, **kwargs):  # e.g. 'captioning.models.encoder'
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    return getattr(module, config["type"])(**obj_args)


def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
    """pprint_dict

    :param outputfun: function to use, defaults to sys.stdout
    :param in_dict: dict to print
    """
    if formatter == 'yaml':
        format_fun = yaml.dump
    elif formatter == 'pretty':
        format_fun = pformat
    for line in format_fun(in_dict).split('\n'):
        outputfun(line)


def merge_a_into_b(a, b):
    # merge dict a into dict b. values in a will overwrite b.
    for k, v in a.items():
        if isinstance(v, dict) and k in b:
            assert isinstance(
                b[k], dict
            ), "Cannot inherit key '{}' from base!".format(k)
            merge_a_into_b(v, b[k])
        else:
            b[k] = v


def load_config(config_file):
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" in config:
        base_config_file = config["inherit_from"]
        base_config_file = os.path.join(
            os.path.dirname(config_file), base_config_file
        )
        assert not os.path.samefile(config_file, base_config_file), \
            "inherit from itself"
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config


def parse_config_or_kwargs(config_file, **kwargs):
    yaml_config = load_config(config_file)
    # passed kwargs will override yaml config
    args = dict(yaml_config, **kwargs)
    return args


def store_yaml(config, config_file):
    with open(config_file, "w") as con_writer:
        yaml.dump(config, con_writer, indent=4, default_flow_style=False)


class MetricImprover:

    def __init__(self, mode):
        assert mode in ("min", "max")
        self.mode = mode
        # min: lower -> better; max: higher -> better
        self.best_value = np.inf if mode == "min" else -np.inf

    def compare(self, x, best_x):
        return x < best_x if self.mode == "min" else x > best_x

    def __call__(self, x):
        if self.compare(x, self.best_value):
            self.best_value = x
            return True
        return False

    def state_dict(self):
        return self.__dict__

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)


def fix_batchnorm(model: torch.nn.Module):
    def inner(module):
        class_name = module.__class__.__name__
        if class_name.find("BatchNorm") != -1:
            module.eval()
    model.apply(inner)


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"pretrained {pretrained} not exist!")
        return

    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained)
        return

    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")

    if "model" in state_dict:
        state_dict = state_dict["model"]
    model_dict = model.state_dict()
    pretrained_dict = {
        k: v for k, v in state_dict.items() if (k in model_dict) and (
            model_dict[k].shape == v.shape)
    }
    output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)


class AveragedModel(torch_average_model):

    def update_parameters(self, model):
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                                 self.n_averaged.to(device)))

        for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
            device = b_swa.device
            b_model_ = b_model.detach().to(device)
            if self.n_averaged == 0:
                b_swa.detach().copy_(b_model_)
            else:
                b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1
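Note (not part of the upload): `load_config` supports an `inherit_from` key, loading the base YAML first and then recursively overlaying the child values via `merge_a_into_b`, so child keys override base keys. A small sketch under the assumption of two hypothetical files:
# Illustrative only: base.yaml and exp.yaml are hypothetical.
#
# base.yaml                 exp.yaml
# ---------                 --------
# optimizer:                inherit_from: base.yaml
#   type: Adam              optimizer:
#   lr: 0.001                 lr: 0.0005
#
config = load_config("exp.yaml")
# -> {"optimizer": {"type": "Adam", "lr": 0.0005}}
args = parse_config_or_kwargs("exp.yaml", seed=1)  # extra kwargs override the YAML values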
audio_to_text/captioning/utils/word2vec/create_word_embedding.py
ADDED
@@ -0,0 +1,67 @@
# coding=utf-8
#!/usr/bin/env python3

import numpy as np
import pandas as pd
import torch
import gensim
from gensim.models import Word2Vec
from tqdm import tqdm
import fire

import sys
import os
sys.path.append(os.getcwd())
from utils.build_vocab import Vocabulary

def create_embedding(vocab_file: str,
                     embed_size: int,
                     output: str,
                     caption_file: str = None,
                     pretrained_weights_path: str = None,
                     **word2vec_kwargs):
    vocabulary = torch.load(vocab_file, map_location="cpu")

    if pretrained_weights_path:
        model = gensim.models.KeyedVectors.load_word2vec_format(
            fname=pretrained_weights_path,
            binary=True,
        )
        if model.vector_size != embed_size:
            assert embed_size < model.vector_size, f"can only reduce dimension, cannot expand from {model.vector_size} to {embed_size}"
            from sklearn.decomposition import PCA
            pca = PCA(n_components=embed_size)
            model.vectors = pca.fit_transform(model.vectors)
    else:
        caption_df = pd.read_json(caption_file)
        caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
        sentences = list(caption_df["tokens"].values)
        epochs = word2vec_kwargs.get("epochs", 10)
        if "epochs" in word2vec_kwargs:
            del word2vec_kwargs["epochs"]
        model = Word2Vec(size=embed_size, min_count=1, **word2vec_kwargs)
        model.build_vocab(sentences=sentences)
        model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)

    word_embeddings = np.random.randn(len(vocabulary), embed_size)

    if isinstance(model, gensim.models.word2vec.Word2Vec):
        model = model.wv
    with tqdm(total=len(vocabulary), ascii=True) as pbar:
        for word, idx in vocabulary.word2idx.items():
            try:
                word_embeddings[idx] = model.get_vector(word)
            except KeyError:
                print(f"word {word} not found in word2vec model, it is randomly initialized!")
            pbar.update()

    np.save(output, word_embeddings)

    print("Finish writing word2vec embeddings to " + output)


if __name__ == "__main__":
    fire.Fire(create_embedding)
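Usage note (not part of the upload): a sketch of exporting word2vec embeddings with the function above; the paths are placeholders, and the `size=` keyword style assumes the gensim 3.x API that this script targets.
# Illustrative only: hypothetical paths; assumes gensim 3.x style Word2Vec (size=..., not vector_size=...).
from utils.word2vec.create_word_embedding import create_embedding

create_embedding(
    vocab_file="data/clotho/vocab.pth",       # vocabulary saved with torch.save
    embed_size=256,
    output="data/clotho/w2v_emb.npy",         # numpy array of shape (vocab_size, embed_size)
    caption_file="data/clotho/train.json",    # used only when no pretrained weights are given
)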
audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml
ADDED
@@ -0,0 +1,22 @@
model:
  encoder:
    type: Cnn14RnnEncoder
    args:
      sample_rate: 32000
      pretrained: ./audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
      freeze_cnn: True
      freeze_cnn_bn: True
      bidirectional: True
      dropout: 0.5
      hidden_size: 256
      num_layers: 3
  decoder:
    type: TransformerDecoder
    args:
      attn_emb_dim: 512
      dropout: 0.2
      emb_dim: 256
      fc_emb_dim: 512
      nlayers: 2
  type: TransformerModel
  args: {}
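Note (not part of the upload): each `type`/`args` pair in this YAML names a class and its constructor keyword arguments, which `train_util.init_obj` instantiates. A sketch that mirrors the first steps of `load_model` in `audio_to_text/inference_waveform.py` (shown later in this upload); the call pattern is the same, only the surrounding script is illustrative.
# Illustrative only: mirrors how load_model turns the encoder entry of this config into an object.
import audio_to_text.captioning.models.encoder as encoder_module
import audio_to_text.captioning.utils.train_util as train_util

config = train_util.parse_config_or_kwargs(
    "audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml")
encoder_cfg = config["model"]["encoder"]          # type: Cnn14RnnEncoder, args: {...}
encoder = train_util.init_obj(encoder_module, encoder_cfg)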
audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8d341dccafcdcfb7009c402afb07f314ab1d613a5f5c42d32407d6c2a821abf
size 41755865
audio_to_text/inference_waveform.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import sys
+import os
+import librosa
+import numpy as np
+import torch
+import audio_to_text.captioning.models
+import audio_to_text.captioning.models.encoder
+import audio_to_text.captioning.models.decoder
+import audio_to_text.captioning.utils.train_util as train_util
+
+
+def load_model(config, checkpoint):
+    # Build the encoder and decoder from the config, load pretrained weights,
+    # and return the assembled captioning model together with its vocabulary.
+    ckpt = torch.load(checkpoint, "cpu")
+    encoder_cfg = config["model"]["encoder"]
+    encoder = train_util.init_obj(
+        audio_to_text.captioning.models.encoder,
+        encoder_cfg
+    )
+    if "pretrained" in encoder_cfg:
+        pretrained = encoder_cfg["pretrained"]
+        train_util.load_pretrained_model(encoder,
+                                         pretrained,
+                                         sys.stdout.write)
+    decoder_cfg = config["model"]["decoder"]
+    if "vocab_size" not in decoder_cfg["args"]:
+        decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
+    decoder = train_util.init_obj(
+        audio_to_text.captioning.models.decoder,
+        decoder_cfg
+    )
+    if "word_embedding" in decoder_cfg:
+        decoder.load_word_embedding(**decoder_cfg["word_embedding"])
+    if "pretrained" in decoder_cfg:
+        pretrained = decoder_cfg["pretrained"]
+        train_util.load_pretrained_model(decoder,
+                                         pretrained,
+                                         sys.stdout.write)
+    model = train_util.init_obj(audio_to_text.captioning.models, config["model"],
+                                encoder=encoder, decoder=decoder)
+    train_util.load_pretrained_model(model, ckpt)
+    model.eval()
+    return {
+        "model": model,
+        "vocabulary": ckpt["vocabulary"]
+    }
+
+
+def decode_caption(word_ids, vocabulary):
+    # Map predicted word ids back to tokens, stopping at <end> and skipping <start>.
+    candidate = []
+    for word_id in word_ids:
+        word = vocabulary[word_id]
+        if word == "<end>":
+            break
+        elif word == "<start>":
+            continue
+        candidate.append(word)
+    candidate = " ".join(candidate)
+    return candidate
+
+
+class AudioCapModel(object):
+    def __init__(self, weight_dir, device="cuda"):
+        config = os.path.join(weight_dir, "config.yaml")
+        self.config = train_util.parse_config_or_kwargs(config)
+        checkpoint = os.path.join(weight_dir, "swa.pth")
+        resumed = load_model(self.config, checkpoint)
+        model = resumed["model"]
+        self.vocabulary = resumed["vocabulary"]
+        self.model = model.to(device)
+        self.device = device
+
+    def caption(self, audio_list):
+        # Accept a single waveform array or a file path, as well as a list of waveforms.
+        if isinstance(audio_list, np.ndarray):
+            audio_list = [audio_list]
+        elif isinstance(audio_list, str):
+            audio_list = [librosa.load(audio_list, sr=32000)[0]]
+
+        captions = []
+        for wav in audio_list:
+            inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
+            wav_len = torch.LongTensor([len(wav)])
+            input_dict = {
+                "mode": "inference",
+                "wav": inputwav,
+                "wav_len": wav_len,
+                "specaug": False,
+                "sample_method": "beam",
+            }
+            print(input_dict)
+            out_dict = self.model(input_dict)
+            caption_batch = [decode_caption(seq, self.vocabulary)
+                             for seq in out_dict["seq"].cpu().numpy()]
+            captions.extend(caption_batch)
+        return captions
+
+    def __call__(self, audio_list):
+        return self.caption(audio_list)
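A minimal usage sketch for the AudioCapModel class above, assuming the Clotho weight directory from this upload; the audio file name is a placeholder for whatever waveform you actually have on disk.

from audio_to_text.inference_waveform import AudioCapModel

# The weight directory must contain config.yaml and swa.pth, as uploaded above.
captioner = AudioCapModel("audio_to_text/clotho_cntrstv_cnn14rnn_trm", device="cpu")

# A path is decoded with librosa at 32 kHz; a raw numpy waveform also works.
captions = captioner("example.wav")  # "example.wav" is a placeholder file
print(captions[0])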
audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c4faa86f30e77df235b5dc1fb6578a18ff2b8a1b0043f47e30acb9ccb53a336
+size 494977221