Spaces:
Runtime error
Runtime error
Initial Commit
Browse files- .gitattributes +1 -0
- README.md +6 -6
- cfg/audiocaps/base.yaml +48 -0
- cfg/audiocaps/large.yaml +48 -0
- cfg/audiocaps_args.yaml +60 -0
- cfg/clotho/base.yaml +48 -0
- cfg/clotho/large.yaml +48 -0
- cfg/clotho_finetune/base.yaml +48 -0
- cfg/clotho_finetune/large.yaml +48 -0
- ckpt/config.json +75 -0
- ckpt/pytorch_model.bin +3 -0
- csv/audiocaps/test.csv +0 -0
- csv/audiocaps/train.csv +0 -0
- csv/audiocaps/valid.csv +0 -0
- csv/clotho/test.csv +0 -0
- csv/clotho/train.csv +0 -0
- csv/clotho/valid.csv +0 -0
- data/__init__.py +0 -0
- data/collator.py +61 -0
- data/infer_clap.py +67 -0
- data/infer_encodec.py +41 -0
- data/preprocess.py +232 -0
- gradio_app.py +60 -0
- inference.ipynb +0 -0
- inference.py +161 -0
- metric/__init__.py +0 -0
- metric/compute_metric.py +24 -0
- metric/compute_metric_from_scratch.py +70 -0
- metric/make_predictions.py +41 -0
- modeling/__init__.py +0 -0
- modeling/enclap_bart.py +548 -0
- modeling/modeling_outputs.py +11 -0
- port_weights.py +42 -0
- requirements.txt +14 -0
- test/bart_test.ipynb +363 -0
- test/clap_test.ipynb +0 -0
- test/dataset_test.ipynb +0 -0
- test/dataset_test.py +57 -0
- test/encodec_test.ipynb +0 -0
- test/encodec_test.py +24 -0
- test/eval_dataset_test.py +42 -0
- test/masking_test.ipynb +117 -0
- test/metric_test.ipynb +260 -0
- test/subsample_test.ipynb +121 -0
- test/train_test.ipynb +261 -0
- test/train_test.py +41 -0
- train.py +527 -0
- train.sh +2 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
ckpt/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
---
|
2 |
title: Enclap
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
title: Enclap
|
3 |
+
emoji: 🔊
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.41.0
|
8 |
+
app_file: gradio_app.py
|
9 |
pinned: false
|
10 |
+
license: openrail
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
cfg/audiocaps/base.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /output
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: csv/audiocaps/train.csv
|
7 |
+
validation_file: csv/audiocaps/valid.csv
|
8 |
+
encodec_base_path: /data/audiocaps/encodec
|
9 |
+
clap_base_path: /data/audiocaps/clap
|
10 |
+
tokenizer_name: facebook/bart-base
|
11 |
+
config_name_or_path: facebook/bart-base
|
12 |
+
model_name_or_path: facebook/bart-base
|
13 |
+
eval_num_captions: 5
|
14 |
+
overwrite_output_dir: False
|
15 |
+
|
16 |
+
# Basic Config
|
17 |
+
encodec_masking_prob: 0.15
|
18 |
+
encodec_masking_span: 10
|
19 |
+
num_train_epochs: 15
|
20 |
+
max_train_steps: null
|
21 |
+
gradient_accumulation_steps: 1
|
22 |
+
per_device_train_batch_size: 16
|
23 |
+
per_device_eval_batch_size: 16
|
24 |
+
split_batches: true
|
25 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
26 |
+
resume_from_checkpoint: null
|
27 |
+
|
28 |
+
# Generation Config
|
29 |
+
max_target_length: 128
|
30 |
+
val_max_target_length: 50
|
31 |
+
|
32 |
+
# Training Hyperparameters
|
33 |
+
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
34 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
35 |
+
lr_scheduler_type: inverse_sqrt
|
36 |
+
learning_rate: 6.5e-5 # peak lr
|
37 |
+
num_warmup_steps: 2000
|
38 |
+
weight_decay: 0.01
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
# Others
|
42 |
+
with_tracking: true
|
43 |
+
report_to: tensorboard
|
44 |
+
ignore_pad_token_for_loss: true
|
45 |
+
preprocessing_num_workers: 32
|
46 |
+
use_slow_tokenizer: false
|
47 |
+
overwrite_cache: false
|
48 |
+
pad_to_max_length: false
|
cfg/audiocaps/large.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /output
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: csv/audiocaps/train.csv
|
7 |
+
validation_file: csv/audiocaps/valid.csv
|
8 |
+
encodec_base_path: /data/audiocaps/encodec
|
9 |
+
clap_base_path: /data/audiocaps/clap
|
10 |
+
tokenizer_name: facebook/bart-large
|
11 |
+
config_name_or_path: facebook/bart-large
|
12 |
+
model_name_or_path: facebook/bart-large
|
13 |
+
eval_num_captions: 5
|
14 |
+
overwrite_output_dir: False
|
15 |
+
|
16 |
+
# Basic Config
|
17 |
+
encodec_masking_prob: 0.15
|
18 |
+
encodec_masking_span: 10
|
19 |
+
num_train_epochs: 15
|
20 |
+
max_train_steps: null
|
21 |
+
gradient_accumulation_steps: 1
|
22 |
+
per_device_train_batch_size: 64
|
23 |
+
per_device_eval_batch_size: 64
|
24 |
+
split_batches: true
|
25 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
26 |
+
resume_from_checkpoint: null
|
27 |
+
|
28 |
+
# Generation Config
|
29 |
+
max_target_length: 128
|
30 |
+
val_max_target_length: 50
|
31 |
+
|
32 |
+
# Training Hyperparameters
|
33 |
+
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
34 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
35 |
+
lr_scheduler_type: inverse_sqrt
|
36 |
+
learning_rate: 3e-5 # peak lr
|
37 |
+
num_warmup_steps: 2000
|
38 |
+
weight_decay: 0.01
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
# Others
|
42 |
+
with_tracking: true
|
43 |
+
report_to: tensorboard
|
44 |
+
ignore_pad_token_for_loss: true
|
45 |
+
preprocessing_num_workers: 32
|
46 |
+
use_slow_tokenizer: false
|
47 |
+
overwrite_cache: false
|
48 |
+
pad_to_max_length: false
|
cfg/audiocaps_args.yaml
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /data/jyk/aac_results/bart_large/audiocaps_3e5_gpu4_1115_2000
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: /workspace/audiobart/csv/AudioCaps/train.csv
|
7 |
+
validation_file: /workspace/audiobart/csv/AudioCaps/val.csv
|
8 |
+
test_file: /workspace/audiobart/csv/AudioCaps/test.csv
|
9 |
+
base_path: /data/jyk/aac_dataset/AudioCaps/encodec_16
|
10 |
+
clap_base_path: /data/jyk/aac_dataset/AudioCaps/clap_audio_fused
|
11 |
+
tokenizer_name: facebook/bart-large
|
12 |
+
# model_name_or_path: /workspace/audiobart/bart/model
|
13 |
+
model_name_or_path: facebook/bart-large
|
14 |
+
num_captions: 5
|
15 |
+
overwrite_output_dir: False
|
16 |
+
|
17 |
+
|
18 |
+
# Training Configs
|
19 |
+
# Basic Config
|
20 |
+
max_encodec_length: 1022
|
21 |
+
only_encoder_epochs: 0
|
22 |
+
only_encodec_epochs: 0
|
23 |
+
clap_masking_prob: -1
|
24 |
+
encodec_masking_prob: 0.15
|
25 |
+
encodec_masking_length: 10
|
26 |
+
random_sampling: true
|
27 |
+
num_train_epochs: 30
|
28 |
+
max_train_steps: null
|
29 |
+
gradient_accumulation_steps: 1
|
30 |
+
per_device_train_batch_size: 64
|
31 |
+
per_device_eval_batch_size: 64
|
32 |
+
split_batches: true
|
33 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
34 |
+
resume_from_checkpoint: null
|
35 |
+
|
36 |
+
# Model & Generation Config
|
37 |
+
max_source_length: 1024
|
38 |
+
max_target_length: 128
|
39 |
+
val_max_target_length: 50
|
40 |
+
num_beams: null
|
41 |
+
pad_to_max_length: false
|
42 |
+
num_subsampling: 0
|
43 |
+
|
44 |
+
# Training Hyperparameters
|
45 |
+
learning_rate: 3e-5 # peak lr
|
46 |
+
# Should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
47 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
48 |
+
lr_scheduler_type: inverse_sqrt
|
49 |
+
# lr_scheduler_type: two_stage_inverse_sqrt
|
50 |
+
weight_decay: 0.01
|
51 |
+
num_warmup_steps: 2000
|
52 |
+
max_grad_norm: 1.0
|
53 |
+
|
54 |
+
# Do not Change
|
55 |
+
with_tracking: true
|
56 |
+
report_to: all
|
57 |
+
ignore_pad_token_for_loss: true
|
58 |
+
preprocessing_num_workers: 32
|
59 |
+
use_slow_tokenizer: false
|
60 |
+
overwrite_cache: false
|
cfg/clotho/base.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /output
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: /csv/clotho/train.csv
|
7 |
+
validation_file: /csv/clotho/valid.csv
|
8 |
+
encodec_base_path: /data/clotho/encodec
|
9 |
+
clap_base_path: /data/clotho/clap
|
10 |
+
tokenizer_name: facebook/bart-base
|
11 |
+
config_name_or_path: facebook/bart-base
|
12 |
+
model_name_or_path: facebook/bart-base
|
13 |
+
eval_num_captions: 5
|
14 |
+
overwrite_output_dir: False
|
15 |
+
|
16 |
+
# Basic Config
|
17 |
+
encodec_masking_prob: 0.15
|
18 |
+
encodec_masking_span: 10
|
19 |
+
num_train_epochs: 15
|
20 |
+
max_train_steps: null
|
21 |
+
gradient_accumulation_steps: 1
|
22 |
+
per_device_train_batch_size: 64
|
23 |
+
per_device_eval_batch_size: 64
|
24 |
+
split_batches: true
|
25 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
26 |
+
resume_from_checkpoint: null
|
27 |
+
|
28 |
+
# Generation Config
|
29 |
+
max_target_length: 128
|
30 |
+
val_max_target_length: 50
|
31 |
+
|
32 |
+
# Training Hyperparameters
|
33 |
+
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
34 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
35 |
+
lr_scheduler_type: inverse_sqrt
|
36 |
+
learning_rate: 4e-5 # peak lr
|
37 |
+
num_warmup_steps: 1000
|
38 |
+
weight_decay: 0.01
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
# Others
|
42 |
+
with_tracking: true
|
43 |
+
report_to: tensorboard
|
44 |
+
ignore_pad_token_for_loss: true
|
45 |
+
preprocessing_num_workers: 32
|
46 |
+
use_slow_tokenizer: false
|
47 |
+
overwrite_cache: false
|
48 |
+
pad_to_max_length: false
|
cfg/clotho/large.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /output
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: /csv/clotho/train.csv
|
7 |
+
validation_file: /csv/clotho/valid.csv
|
8 |
+
encodec_base_path: /data/clotho/encodec
|
9 |
+
clap_base_path: /data/clotho/clap
|
10 |
+
tokenizer_name: facebook/bart-large
|
11 |
+
config_name_or_path: facebook/bart-large
|
12 |
+
model_name_or_path: facebook/bart-large
|
13 |
+
eval_num_captions: 5
|
14 |
+
overwrite_output_dir: False
|
15 |
+
|
16 |
+
# Basic Config
|
17 |
+
encodec_masking_prob: 0.15
|
18 |
+
encodec_masking_span: 10
|
19 |
+
num_train_epochs: 15
|
20 |
+
max_train_steps: null
|
21 |
+
gradient_accumulation_steps: 1
|
22 |
+
per_device_train_batch_size: 64
|
23 |
+
per_device_eval_batch_size: 64
|
24 |
+
split_batches: true
|
25 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
26 |
+
resume_from_checkpoint: null
|
27 |
+
|
28 |
+
# Generation Config
|
29 |
+
max_target_length: 128
|
30 |
+
val_max_target_length: 50
|
31 |
+
|
32 |
+
# Training Hyperparameters
|
33 |
+
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
34 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
35 |
+
lr_scheduler_type: inverse_sqrt
|
36 |
+
learning_rate: 2.5e-5 # peak lr
|
37 |
+
num_warmup_steps: 1000
|
38 |
+
weight_decay: 0.01
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
# Others
|
42 |
+
with_tracking: true
|
43 |
+
report_to: tensorboard
|
44 |
+
ignore_pad_token_for_loss: true
|
45 |
+
preprocessing_num_workers: 32
|
46 |
+
use_slow_tokenizer: false
|
47 |
+
overwrite_cache: false
|
48 |
+
pad_to_max_length: false
|
cfg/clotho_finetune/base.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /output
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: /csv/clotho/train.csv
|
7 |
+
validation_file: /csv/clotho/valid.csv
|
8 |
+
encodec_base_path: /data/clotho/encodec
|
9 |
+
clap_base_path: /data/clotho/clap
|
10 |
+
tokenizer_name: facebook/bart-base
|
11 |
+
config_name_or_path: facebook/bart-base
|
12 |
+
model_name_or_path: /data/enclap_audiocaps
|
13 |
+
eval_num_captions: 5
|
14 |
+
overwrite_output_dir: False
|
15 |
+
|
16 |
+
# Basic Config
|
17 |
+
encodec_masking_prob: 0.15
|
18 |
+
encodec_masking_span: 10
|
19 |
+
num_train_epochs: 15
|
20 |
+
max_train_steps: null
|
21 |
+
gradient_accumulation_steps: 1
|
22 |
+
per_device_train_batch_size: 64
|
23 |
+
per_device_eval_batch_size: 64
|
24 |
+
split_batches: true
|
25 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
26 |
+
resume_from_checkpoint: null
|
27 |
+
|
28 |
+
# Generation Config
|
29 |
+
max_target_length: 128
|
30 |
+
val_max_target_length: 50
|
31 |
+
|
32 |
+
# Training Hyperparameters
|
33 |
+
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
34 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
35 |
+
lr_scheduler_type: inverse_sqrt
|
36 |
+
learning_rate: 2e-5 # peak lr
|
37 |
+
num_warmup_steps: 1000
|
38 |
+
weight_decay: 0.01
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
# Others
|
42 |
+
with_tracking: true
|
43 |
+
report_to: tensorboard
|
44 |
+
ignore_pad_token_for_loss: true
|
45 |
+
preprocessing_num_workers: 32
|
46 |
+
use_slow_tokenizer: false
|
47 |
+
overwrite_cache: false
|
48 |
+
pad_to_max_length: false
|
cfg/clotho_finetune/large.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Experiment Config for each experiment
|
2 |
+
output_dir: /output
|
3 |
+
logging_dir: runs/tb_log
|
4 |
+
logging_steps: 10
|
5 |
+
seed: 1115
|
6 |
+
train_file: /csv/clotho/train.csv
|
7 |
+
validation_file: /csv/clotho/valid.csv
|
8 |
+
encodec_base_path: /data/clotho/encodec
|
9 |
+
clap_base_path: /data/clotho/clap
|
10 |
+
tokenizer_name: facebook/bart-large
|
11 |
+
config_name_or_path: facebook/bart-large
|
12 |
+
model_name_or_path: /data/enclap_audiocaps
|
13 |
+
eval_num_captions: 5
|
14 |
+
overwrite_output_dir: False
|
15 |
+
|
16 |
+
# Basic Config
|
17 |
+
encodec_masking_prob: 0.15
|
18 |
+
encodec_masking_span: 10
|
19 |
+
num_train_epochs: 15
|
20 |
+
max_train_steps: null
|
21 |
+
gradient_accumulation_steps: 1
|
22 |
+
per_device_train_batch_size: 64
|
23 |
+
per_device_eval_batch_size: 64
|
24 |
+
split_batches: true
|
25 |
+
checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
|
26 |
+
resume_from_checkpoint: null
|
27 |
+
|
28 |
+
# Generation Config
|
29 |
+
max_target_length: 128
|
30 |
+
val_max_target_length: 50
|
31 |
+
|
32 |
+
# Training Hyperparameters
|
33 |
+
# "lr_schedulre_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
|
34 |
+
# "constant", "constant_with_warmpup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
|
35 |
+
lr_scheduler_type: inverse_sqrt
|
36 |
+
learning_rate: 1.25e-5 # peak lr
|
37 |
+
num_warmup_steps: 1000
|
38 |
+
weight_decay: 0.01
|
39 |
+
max_grad_norm: 1.0
|
40 |
+
|
41 |
+
# Others
|
42 |
+
with_tracking: true
|
43 |
+
report_to: tensorboard
|
44 |
+
ignore_pad_token_for_loss: true
|
45 |
+
preprocessing_num_workers: 32
|
46 |
+
use_slow_tokenizer: false
|
47 |
+
overwrite_cache: false
|
48 |
+
pad_to_max_length: false
|
ckpt/config.json
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "facebook/bart-base",
|
3 |
+
"activation_dropout": 0.1,
|
4 |
+
"activation_function": "gelu",
|
5 |
+
"add_bias_logits": false,
|
6 |
+
"add_final_layer_norm": false,
|
7 |
+
"architectures": [
|
8 |
+
"BartModel"
|
9 |
+
],
|
10 |
+
"attention_dropout": 0.1,
|
11 |
+
"bos_token_id": 0,
|
12 |
+
"classif_dropout": 0.1,
|
13 |
+
"classifier_dropout": 0.0,
|
14 |
+
"d_model": 768,
|
15 |
+
"decoder_attention_heads": 12,
|
16 |
+
"decoder_ffn_dim": 3072,
|
17 |
+
"decoder_layerdrop": 0.0,
|
18 |
+
"decoder_layers": 6,
|
19 |
+
"decoder_start_token_id": 2,
|
20 |
+
"dropout": 0.1,
|
21 |
+
"early_stopping": true,
|
22 |
+
"encoder_attention_heads": 12,
|
23 |
+
"encoder_ffn_dim": 3072,
|
24 |
+
"encoder_layerdrop": 0.0,
|
25 |
+
"encoder_layers": 6,
|
26 |
+
"eos_token_id": 2,
|
27 |
+
"forced_bos_token_id": 0,
|
28 |
+
"forced_eos_token_id": 2,
|
29 |
+
"gradient_checkpointing": false,
|
30 |
+
"id2label": {
|
31 |
+
"0": "LABEL_0",
|
32 |
+
"1": "LABEL_1",
|
33 |
+
"2": "LABEL_2"
|
34 |
+
},
|
35 |
+
"init_std": 0.02,
|
36 |
+
"is_encoder_decoder": true,
|
37 |
+
"label2id": {
|
38 |
+
"LABEL_0": 0,
|
39 |
+
"LABEL_1": 1,
|
40 |
+
"LABEL_2": 2
|
41 |
+
},
|
42 |
+
"max_position_embeddings": 1024,
|
43 |
+
"model_type": "bart",
|
44 |
+
"no_repeat_ngram_size": 3,
|
45 |
+
"normalize_before": false,
|
46 |
+
"normalize_embedding": true,
|
47 |
+
"num_beams": 4,
|
48 |
+
"num_hidden_layers": 6,
|
49 |
+
"pad_token_id": 1,
|
50 |
+
"scale_embedding": false,
|
51 |
+
"task_specific_params": {
|
52 |
+
"summarization": {
|
53 |
+
"length_penalty": 1.0,
|
54 |
+
"max_length": 128,
|
55 |
+
"min_length": 12,
|
56 |
+
"num_beams": 4
|
57 |
+
},
|
58 |
+
"summarization_cnn": {
|
59 |
+
"length_penalty": 2.0,
|
60 |
+
"max_length": 142,
|
61 |
+
"min_length": 56,
|
62 |
+
"num_beams": 4
|
63 |
+
},
|
64 |
+
"summarization_xsum": {
|
65 |
+
"length_penalty": 1.0,
|
66 |
+
"max_length": 62,
|
67 |
+
"min_length": 11,
|
68 |
+
"num_beams": 6
|
69 |
+
}
|
70 |
+
},
|
71 |
+
"torch_dtype": "float32",
|
72 |
+
"transformers_version": "4.29.0",
|
73 |
+
"use_cache": true,
|
74 |
+
"vocab_size": 50265
|
75 |
+
}
|
ckpt/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fef9cc30fdf47b82a8bb846d418ac3ef893b4d10e909fafbbe3ed8a1931cf23
|
3 |
+
size 663433954
|
csv/audiocaps/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
csv/audiocaps/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
csv/audiocaps/valid.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
csv/clotho/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
csv/clotho/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
csv/clotho/valid.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/__init__.py
ADDED
File without changes
|
data/collator.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import BatchEncoding, DataCollatorForSeq2Seq
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class DataCollatorForEnClapBart(DataCollatorForSeq2Seq):
|
9 |
+
input_pad_token_id: int = 1024
|
10 |
+
num_rvq: int = 16
|
11 |
+
|
12 |
+
def __call__(self, features, return_tensors=None):
|
13 |
+
if return_tensors is None:
|
14 |
+
return_tensors = self.return_tensors
|
15 |
+
|
16 |
+
batch_size = len(features)
|
17 |
+
# stacked_features = {k: [f[k] for f in features] for k in features[0]}
|
18 |
+
clap_embedding = torch.Tensor(
|
19 |
+
[feature["clap_embedding"] for feature in features]
|
20 |
+
)
|
21 |
+
|
22 |
+
pad_token_id = self.tokenizer.pad_token_id
|
23 |
+
self.tokenizer.pad_token_id = self.input_pad_token_id
|
24 |
+
keys = ["input_ids", "mcm_labels"]
|
25 |
+
tmp_key_map = {"input_ids": "input_ids", "mcm_labels": "labels"}
|
26 |
+
input_features = super().__call__(
|
27 |
+
[
|
28 |
+
{tmp_key_map[key]: feature[key][:, i] for key in keys}
|
29 |
+
for feature in features
|
30 |
+
for i in range(feature[keys[0]].shape[-1])
|
31 |
+
],
|
32 |
+
return_tensors,
|
33 |
+
)
|
34 |
+
|
35 |
+
self.tokenizer.pad_token_id = 1
|
36 |
+
keys = ["encodec_mask", "attention_mask", "labels"]
|
37 |
+
tmp_key_map = {
|
38 |
+
"encodec_mask": "input_ids",
|
39 |
+
"attention_mask": "attention_mask",
|
40 |
+
"labels": "labels",
|
41 |
+
}
|
42 |
+
other_features = super().__call__(
|
43 |
+
[{tmp_key_map[key]: feature[key] for key in keys} for feature in features],
|
44 |
+
return_tensors,
|
45 |
+
)
|
46 |
+
self.tokenizer.pad_token_id = pad_token_id
|
47 |
+
|
48 |
+
return BatchEncoding(
|
49 |
+
{
|
50 |
+
"input_ids": input_features["input_ids"]
|
51 |
+
.reshape(batch_size, self.num_rvq, -1)
|
52 |
+
.transpose(1, 2),
|
53 |
+
"mcm_labels": input_features["labels"]
|
54 |
+
.reshape(batch_size, self.num_rvq, -1)
|
55 |
+
.transpose(1, 2),
|
56 |
+
"attention_mask": other_features["attention_mask"],
|
57 |
+
"encodec_mask": other_features["input_ids"],
|
58 |
+
"labels": other_features["labels"],
|
59 |
+
"clap_embedding": clap_embedding,
|
60 |
+
}
|
61 |
+
)
|
data/infer_clap.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import librosa
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from laion_clap import CLAP_Module
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
parser = argparse.ArgumentParser()
|
12 |
+
parser.add_argument(
|
13 |
+
"--data_path",
|
14 |
+
"-d",
|
15 |
+
required=True,
|
16 |
+
type=str,
|
17 |
+
help="Path of the original wav files",
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--save_path",
|
21 |
+
"-s",
|
22 |
+
required=True,
|
23 |
+
type=str,
|
24 |
+
help="Path to save the clap audio embedding '.npy' files",
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--clap_ckpt",
|
28 |
+
"-c",
|
29 |
+
required=True,
|
30 |
+
type=str,
|
31 |
+
help="Path of the pretrained clap checkpoint",
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--enable_fusion",
|
35 |
+
"-e",
|
36 |
+
default=True,
|
37 |
+
type=bool,
|
38 |
+
help="Whether to enable the feature fusion of the clap model. Depends on the clap checkpoint you are using",
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--audio_encoder",
|
42 |
+
"-a",
|
43 |
+
default="HTSAT-tiny",
|
44 |
+
type=str,
|
45 |
+
help="Audio encoder of the clap model. Depends on the clap checkpoint you are using",
|
46 |
+
)
|
47 |
+
args = parser.parse_args()
|
48 |
+
|
49 |
+
model = CLAP_Module(enable_fusion=args.enable_fusion, aencoder=args.audio_encoder)
|
50 |
+
model.load_ckpt(args.clap_ckpt)
|
51 |
+
data_path = Path(args.data_path)
|
52 |
+
save_path = Path(args.save_path)
|
53 |
+
|
54 |
+
with torch.no_grad():
|
55 |
+
for wav_path in tqdm(
|
56 |
+
data_path.glob("**/*.wav"), dynamic_ncols=True, colour="yellow"
|
57 |
+
):
|
58 |
+
wav, _ = librosa.load(wav_path, sr=48000)
|
59 |
+
|
60 |
+
clap_embeding = model.get_audio_embedding_from_data(
|
61 |
+
x=wav[np.newaxis], use_tensor=False
|
62 |
+
)
|
63 |
+
clap_embeding = clap_embeding.squeeze(axis=0)
|
64 |
+
|
65 |
+
out_path = save_path / wav_path.with_suffix(".npy").relative_to(data_path)
|
66 |
+
out_path.parent.mkdir(exist_ok=True)
|
67 |
+
np.save(out_path, clap_embeding)
|
data/infer_encodec.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
from encodec import EncodecModel
|
8 |
+
from encodec.utils import convert_audio
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
if __name__ == "__main__":
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument(
|
14 |
+
"--data_path", type=str, required=True, help="Path of the original wav files"
|
15 |
+
)
|
16 |
+
parser.add_argument(
|
17 |
+
"--save_path", type=str, required=True, help="Path to save encodec .npy files"
|
18 |
+
)
|
19 |
+
args = parser.parse_args()
|
20 |
+
|
21 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
model = EncodecModel.encodec_model_24khz()
|
23 |
+
model.set_target_bandwidth(12.0)
|
24 |
+
model = model.to(device)
|
25 |
+
|
26 |
+
data_path = Path(args.data_path)
|
27 |
+
save_path = Path(args.save_path)
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
for wav_path in tqdm(data_path.glob("**/*.wav")):
|
31 |
+
wav, sr = torchaudio.load(wav_path)
|
32 |
+
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
33 |
+
wav = wav.unsqueeze(0).to(device)
|
34 |
+
encoded_frames = model.encode(wav)
|
35 |
+
|
36 |
+
codes = torch.cat([codebook for codebook, _ in encoded_frames], dim=-1)
|
37 |
+
codes = codes.cpu().squeeze(0).transpose(-1, -2).detach().numpy()
|
38 |
+
|
39 |
+
out_path = save_path / wav_path.with_suffix(".npy").relative_to(data_path)
|
40 |
+
out_path.parent.mkdir(exist_ok=True)
|
41 |
+
np.save(out_path, codes)
|
data/preprocess.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from pathlib import Path
|
3 |
+
from random import randint
|
4 |
+
from typing import Optional, Tuple
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from transformers import BartTokenizerFast
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class Preprocessor:
|
13 |
+
encodec_base_path: Path
|
14 |
+
clap_base_path: Path
|
15 |
+
tokenizer: BartTokenizerFast = BartTokenizerFast.from_pretrained(
|
16 |
+
"facebook/bart-base"
|
17 |
+
)
|
18 |
+
max_length: int = 1024
|
19 |
+
mcm_masking_prob: float = 0.15
|
20 |
+
mcm_masking_span: int = 10
|
21 |
+
label_pad_token_id: int = -100
|
22 |
+
mask_token_id: int = 1024
|
23 |
+
num_eval_captions: int = 5
|
24 |
+
|
25 |
+
def __post_init__(self):
|
26 |
+
if isinstance(self.encodec_base_path, str):
|
27 |
+
self.encodec_base_path = Path(self.encodec_base_path)
|
28 |
+
if isinstance(self.clap_base_path, str):
|
29 |
+
self.clap_base_path = Path(self.clap_base_path)
|
30 |
+
if isinstance(self.tokenizer, str):
|
31 |
+
self.tokenizer = BartTokenizerFast.from_pretrained(self.tokenizer)
|
32 |
+
|
33 |
+
def preprocess_train(self, example):
|
34 |
+
path = example["file_path"]
|
35 |
+
encodec = np.load(self.encodec_base_path / path)
|
36 |
+
clap_embedding = np.load(self.clap_base_path / path)
|
37 |
+
encodec_mask = np.array(
|
38 |
+
[0, 0] + [1] * min(encodec.shape[0], self.max_length - 3) + [0]
|
39 |
+
)
|
40 |
+
attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
|
41 |
+
np.int64
|
42 |
+
)
|
43 |
+
target_text = self.tokenizer(text_target=example["caption"])
|
44 |
+
|
45 |
+
if encodec.shape[0] + 3 > self.max_length:
|
46 |
+
start = randint(0, encodec.shape[0] - self.max_length + 3)
|
47 |
+
encodec = encodec[start : start + self.max_length - 3]
|
48 |
+
|
49 |
+
mcm_labels = None
|
50 |
+
if self.mcm_masking_prob > 0:
|
51 |
+
num_rvq = encodec.shape[-1]
|
52 |
+
mcm_mask, _ = _compute_mask_indices(
|
53 |
+
encodec.T.shape, self.mcm_masking_prob, self.mcm_masking_span
|
54 |
+
)
|
55 |
+
mcm_mask = mcm_mask.T
|
56 |
+
mcm_labels = np.where(mcm_mask, encodec, self.label_pad_token_id)
|
57 |
+
mcm_labels = np.concatenate(
|
58 |
+
[
|
59 |
+
np.ones((2, num_rvq), dtype=np.int64) * self.label_pad_token_id,
|
60 |
+
mcm_labels,
|
61 |
+
np.ones((1, num_rvq), dtype=np.int64) * self.label_pad_token_id,
|
62 |
+
],
|
63 |
+
axis=0,
|
64 |
+
)
|
65 |
+
encodec[mcm_mask] = self.mask_token_id
|
66 |
+
|
67 |
+
encodec = np.concatenate(
|
68 |
+
[
|
69 |
+
np.ones((2, num_rvq), dtype=np.int64) * self.tokenizer.bos_token_id,
|
70 |
+
encodec,
|
71 |
+
np.ones((1, num_rvq), dtype=np.int64) * self.tokenizer.eos_token_id,
|
72 |
+
],
|
73 |
+
axis=0,
|
74 |
+
)
|
75 |
+
|
76 |
+
return {
|
77 |
+
"input_ids": encodec,
|
78 |
+
"clap_embedding": clap_embedding,
|
79 |
+
"encodec_mask": encodec_mask,
|
80 |
+
"attention_mask": attention_mask,
|
81 |
+
"mcm_labels": mcm_labels,
|
82 |
+
"labels": target_text["input_ids"],
|
83 |
+
}
|
84 |
+
|
85 |
+
def preprocess_eval(self, example):
|
86 |
+
path = example["file_path"]
|
87 |
+
encodec = np.load(self.encodec_base_path / path)
|
88 |
+
clap_embedding = np.load(self.clap_base_path / path)
|
89 |
+
attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
|
90 |
+
np.int64
|
91 |
+
)
|
92 |
+
|
93 |
+
if encodec.shape[0] + 3 > self.max_length:
|
94 |
+
encodec = encodec[: self.max_length - 3]
|
95 |
+
|
96 |
+
captions = []
|
97 |
+
for i in range(self.num_eval_captions):
|
98 |
+
captions.append(example[f"caption_{i+1}"])
|
99 |
+
|
100 |
+
return {
|
101 |
+
"input_ids": encodec,
|
102 |
+
"attention_mask": attention_mask,
|
103 |
+
"clap": clap_embedding,
|
104 |
+
"captions": captions,
|
105 |
+
}
|
106 |
+
|
107 |
+
|
108 |
+
def _compute_mask_indices(
|
109 |
+
shape: Tuple[int, int],
|
110 |
+
mask_prob: float,
|
111 |
+
mask_length: int,
|
112 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
113 |
+
min_masks: int = 0,
|
114 |
+
) -> np.ndarray:
|
115 |
+
"""
|
116 |
+
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
117 |
+
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
118 |
+
CPU as part of the preprocessing during training.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
122 |
+
the first element is the batch size and the second element is the length of the axis to span.
|
123 |
+
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
124 |
+
independently generated mask spans of length `mask_length` is computed by
|
125 |
+
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
126 |
+
actual percentage will be smaller.
|
127 |
+
mask_length: size of the mask
|
128 |
+
min_masks: minimum number of masked spans
|
129 |
+
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
130 |
+
each batch dimension.
|
131 |
+
"""
|
132 |
+
batch_size, sequence_length = shape
|
133 |
+
|
134 |
+
if mask_length < 1:
|
135 |
+
raise ValueError("`mask_length` has to be bigger than 0.")
|
136 |
+
|
137 |
+
if mask_length > sequence_length:
|
138 |
+
raise ValueError(
|
139 |
+
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
140 |
+
f" and `sequence_length`: {sequence_length}`"
|
141 |
+
)
|
142 |
+
|
143 |
+
# epsilon is used for probabilistic rounding
|
144 |
+
epsilon = np.random.rand(1).item()
|
145 |
+
|
146 |
+
def compute_num_masked_span(input_length):
|
147 |
+
"""Given input length, compute how many spans should be masked"""
|
148 |
+
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
149 |
+
num_masked_span = max(num_masked_span, min_masks)
|
150 |
+
|
151 |
+
# make sure num masked span <= sequence_length
|
152 |
+
if num_masked_span * mask_length > sequence_length:
|
153 |
+
num_masked_span = sequence_length // mask_length
|
154 |
+
|
155 |
+
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
156 |
+
if input_length - (mask_length - 1) < num_masked_span:
|
157 |
+
num_masked_span = max(input_length - (mask_length - 1), 0)
|
158 |
+
|
159 |
+
return num_masked_span
|
160 |
+
|
161 |
+
# compute number of masked spans in batch
|
162 |
+
input_lengths = (
|
163 |
+
attention_mask.sum(-1).detach().tolist()
|
164 |
+
if attention_mask is not None
|
165 |
+
else [sequence_length for _ in range(batch_size)]
|
166 |
+
)
|
167 |
+
|
168 |
+
# SpecAugment mask to fill
|
169 |
+
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
170 |
+
spec_aug_mask_idxs = []
|
171 |
+
|
172 |
+
max_num_masked_span = compute_num_masked_span(sequence_length)
|
173 |
+
|
174 |
+
if max_num_masked_span == 0:
|
175 |
+
return spec_aug_mask
|
176 |
+
|
177 |
+
for input_length in input_lengths:
|
178 |
+
# compute num of masked spans for this input
|
179 |
+
num_masked_span = compute_num_masked_span(input_length)
|
180 |
+
|
181 |
+
# get random indices to mask
|
182 |
+
spec_aug_mask_idx = np.random.choice(
|
183 |
+
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
184 |
+
)
|
185 |
+
|
186 |
+
# pick first sampled index that will serve as a dummy index to pad vector
|
187 |
+
# to ensure same dimension for all batches due to probabilistic rounding
|
188 |
+
# Picking first sample just pads those vectors twice.
|
189 |
+
if len(spec_aug_mask_idx) == 0:
|
190 |
+
# this case can only happen if `input_length` is strictly smaller then
|
191 |
+
# `sequence_length` in which case the last token has to be a padding
|
192 |
+
# token which we can use as a dummy mask id
|
193 |
+
dummy_mask_idx = sequence_length - 1
|
194 |
+
else:
|
195 |
+
dummy_mask_idx = spec_aug_mask_idx[0]
|
196 |
+
|
197 |
+
spec_aug_mask_idx = np.concatenate(
|
198 |
+
[
|
199 |
+
spec_aug_mask_idx,
|
200 |
+
np.ones(max_num_masked_span - num_masked_span, dtype=np.int32)
|
201 |
+
* dummy_mask_idx,
|
202 |
+
]
|
203 |
+
)
|
204 |
+
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
205 |
+
|
206 |
+
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
207 |
+
|
208 |
+
# expand masked indices to masked spans
|
209 |
+
spec_aug_mask_idxs = np.broadcast_to(
|
210 |
+
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
211 |
+
)
|
212 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
|
213 |
+
batch_size, max_num_masked_span * mask_length
|
214 |
+
)
|
215 |
+
|
216 |
+
# add offset to the starting indexes so that indexes now create a span
|
217 |
+
offsets = np.arange(mask_length)[None, None, :]
|
218 |
+
offsets = np.broadcast_to(
|
219 |
+
offsets, (batch_size, max_num_masked_span, mask_length)
|
220 |
+
).reshape(batch_size, max_num_masked_span * mask_length)
|
221 |
+
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
222 |
+
|
223 |
+
# ensure that we cannot have indices larger than sequence_length
|
224 |
+
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
225 |
+
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = (
|
226 |
+
sequence_length - 1
|
227 |
+
)
|
228 |
+
|
229 |
+
# scatter indices to mask
|
230 |
+
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
231 |
+
|
232 |
+
return torch.from_numpy(spec_aug_mask), spec_aug_mask_idxs
|
gradio_app.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from transformers import AutoProcessor
|
8 |
+
|
9 |
+
from inference import EnClap
|
10 |
+
|
11 |
+
|
12 |
+
def input_toggle(choice: str):
|
13 |
+
if choice == "file":
|
14 |
+
return gr.update(visible=True), gr.update(visible=False)
|
15 |
+
return gr.update(visible=False), gr.update(visible=True)
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == "__main__":
|
19 |
+
import logging
|
20 |
+
|
21 |
+
logging.getLogger().setLevel(logging.INFO)
|
22 |
+
ckpt_path = "./ckpt" # os.getenv("ckpt_path")
|
23 |
+
device = "cpu" # os.getenv("device")
|
24 |
+
|
25 |
+
enclap = EnClap(ckpt_path=ckpt_path, device=device)
|
26 |
+
|
27 |
+
def run_enclap(
|
28 |
+
input_type: str,
|
29 |
+
file_input: Tuple[int, np.ndarray],
|
30 |
+
mic_input: Tuple[int, np.ndarray],
|
31 |
+
seed: int,
|
32 |
+
) -> str:
|
33 |
+
print(input_type, file_input, mic_input)
|
34 |
+
input = file_input if input_type == "file" else mic_input
|
35 |
+
if input is None:
|
36 |
+
raise gr.Error("Input audio was not provided.")
|
37 |
+
res, audio = input
|
38 |
+
torch.manual_seed(seed)
|
39 |
+
return enclap.infer_from_audio(torch.from_numpy(audio), res)[0]
|
40 |
+
|
41 |
+
with gr.Blocks() as demo:
|
42 |
+
with gr.Row():
|
43 |
+
with gr.Column():
|
44 |
+
radio = gr.Radio(
|
45 |
+
["file", "mic"],
|
46 |
+
value="file",
|
47 |
+
label="Choose the input method of the audio.",
|
48 |
+
)
|
49 |
+
file = gr.Audio(label="Input", visible=True)
|
50 |
+
mic = gr.Mic(label="Input", visible=False)
|
51 |
+
slider = gr.Slider(minimum=0, maximum=100, label="Seed")
|
52 |
+
radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
|
53 |
+
button = gr.Button("Run", label="run")
|
54 |
+
with gr.Column():
|
55 |
+
output = gr.Text(label="Output")
|
56 |
+
button.click(
|
57 |
+
fn=run_enclap, inputs=[radio, file, mic, slider], outputs=output
|
58 |
+
)
|
59 |
+
|
60 |
+
demo.launch()
|
inference.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
inference.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
from encodec import EncodecModel
|
7 |
+
from encodec.utils import convert_audio
|
8 |
+
from laion_clap import CLAP_Module
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
|
11 |
+
from modeling.enclap_bart import EnClapBartConfig, EnClapBartForConditionalGeneration
|
12 |
+
|
13 |
+
|
14 |
+
class EnClap:
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
ckpt_path: str,
|
18 |
+
clap_audio_model: str = "HTSAT-tiny",
|
19 |
+
clap_enable_fusion = True,
|
20 |
+
device: str = "cuda",
|
21 |
+
):
|
22 |
+
config = EnClapBartConfig.from_pretrained(ckpt_path)
|
23 |
+
self.device = device
|
24 |
+
self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
|
25 |
+
self.model = (
|
26 |
+
EnClapBartForConditionalGeneration.from_pretrained(ckpt_path)
|
27 |
+
.to(self.device)
|
28 |
+
.eval()
|
29 |
+
)
|
30 |
+
|
31 |
+
self.encodec = EncodecModel.encodec_model_24khz().to(self.device)
|
32 |
+
self.encodec.set_target_bandwidth(12.0)
|
33 |
+
|
34 |
+
self.clap_model = CLAP_Module(enable_fusion=clap_enable_fusion, amodel=clap_audio_model, device=self.device)
|
35 |
+
self.clap_model.load_ckpt()
|
36 |
+
|
37 |
+
self.generation_config = {
|
38 |
+
"_from_model_config": True,
|
39 |
+
"bos_token_id": 0,
|
40 |
+
"decoder_start_token_id": 2,
|
41 |
+
"early_stopping": True,
|
42 |
+
"eos_token_id": 2,
|
43 |
+
"forced_bos_token_id": 0,
|
44 |
+
"forced_eos_token_id": 2,
|
45 |
+
"no_repeat_ngram_size": 3,
|
46 |
+
"num_beams": 4,
|
47 |
+
"pad_token_id": 1,
|
48 |
+
"max_length": 50,
|
49 |
+
}
|
50 |
+
self.scale_factor = 2**15
|
51 |
+
self.max_seq_len = config.max_position_embeddings - 3
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
def infer_from_audio_file(
|
55 |
+
self, audio_file: str, generation_config: Dict[str, Any] = None
|
56 |
+
) -> str:
|
57 |
+
if generation_config is None:
|
58 |
+
generation_config = self.generation_config
|
59 |
+
audio, res = torchaudio.load(audio_file)
|
60 |
+
return self.infer_from_audio(audio[0], res)
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def infer_from_audio(
|
64 |
+
self, audio: torch.Tensor, res: int, generation_config: Dict[str, Any] = None
|
65 |
+
) -> str:
|
66 |
+
if generation_config is None:
|
67 |
+
generation_config = self.generation_config
|
68 |
+
if audio.dtype == torch.int or audio.dtype == torch.short:
|
69 |
+
audio = audio / self.scale_factor
|
70 |
+
encodec_audio = (
|
71 |
+
convert_audio(
|
72 |
+
audio.unsqueeze(0), res, self.encodec.sample_rate, self.encodec.channels
|
73 |
+
)
|
74 |
+
.unsqueeze(0)
|
75 |
+
.to(self.device)
|
76 |
+
)
|
77 |
+
encodec_frames = self.encodec.encode(encodec_audio)
|
78 |
+
encodec_frames = torch.cat(
|
79 |
+
[codebook for codebook, _ in encodec_frames], dim=-1
|
80 |
+
).mT
|
81 |
+
|
82 |
+
clap_audio = torchaudio.transforms.Resample(res, 48000)(audio).unsqueeze(0)
|
83 |
+
clap_embedding = self.clap_model.get_audio_embedding_from_data(clap_audio, use_tensor=True)
|
84 |
+
|
85 |
+
return self._infer(encodec_frames, clap_embedding, generation_config)
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def _infer(
|
89 |
+
self,
|
90 |
+
encodec_frames: torch.LongTensor,
|
91 |
+
clap_embedding: torch.Tensor,
|
92 |
+
generation_config: Dict[str, Any] = None,
|
93 |
+
) -> str:
|
94 |
+
input_ids = torch.cat(
|
95 |
+
[
|
96 |
+
torch.ones(
|
97 |
+
(encodec_frames.shape[0], 2, encodec_frames.shape[-1]),
|
98 |
+
dtype=torch.long,
|
99 |
+
).to(self.device)
|
100 |
+
* self.tokenizer.bos_token_id,
|
101 |
+
encodec_frames[:, : self.max_seq_len],
|
102 |
+
torch.ones(
|
103 |
+
(encodec_frames.shape[0], 1, encodec_frames.shape[-1]),
|
104 |
+
dtype=torch.long,
|
105 |
+
).to(self.device)
|
106 |
+
* self.tokenizer.eos_token_id,
|
107 |
+
],
|
108 |
+
dim=1,
|
109 |
+
)
|
110 |
+
encodec_mask = torch.LongTensor(
|
111 |
+
[[0, 0] + [1] * (input_ids.shape[1] - 3) + [0]]
|
112 |
+
).to(self.device)
|
113 |
+
|
114 |
+
enclap_bart_inputs = {
|
115 |
+
"input_ids": input_ids,
|
116 |
+
"encodec_mask": encodec_mask,
|
117 |
+
"clap_embedding": clap_embedding,
|
118 |
+
}
|
119 |
+
|
120 |
+
results = self.model.generate(**enclap_bart_inputs, **generation_config)
|
121 |
+
caption = self.tokenizer.batch_decode(results, skip_special_tokens=True)
|
122 |
+
|
123 |
+
return caption
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def infer_from_encodec(
|
127 |
+
self,
|
128 |
+
file_path,
|
129 |
+
clap_path: str = "clap",
|
130 |
+
generation_config: Dict[str, Any] = None,
|
131 |
+
):
|
132 |
+
if generation_config is None:
|
133 |
+
generation_config = self.generation_config
|
134 |
+
input_ids = np.load(file_path)
|
135 |
+
if input_ids.shape[0] > self.max_encodec_length:
|
136 |
+
input_ids = input_ids[: self.max_encodec_length, :]
|
137 |
+
input_length = input_ids.shape[0]
|
138 |
+
input_ids = np.concatenate([input_ids, self.eos_padding], axis=0)
|
139 |
+
input_ids = torch.LongTensor(input_ids)
|
140 |
+
input_ids = input_ids.unsqueeze(0).to(self.device)
|
141 |
+
attention_mask = (
|
142 |
+
torch.ones(input_length + 3, dtype=torch.int64).unsqueeze(0).to(self.device)
|
143 |
+
)
|
144 |
+
eos_mask = [0] * (input_length + 3)
|
145 |
+
eos_mask[input_length + 2] = 1
|
146 |
+
eos_mask = torch.BoolTensor(eos_mask).unsqueeze(0)
|
147 |
+
# Load CLAP
|
148 |
+
clap_path = file_path.replace("encodec_16", clap_path)
|
149 |
+
clap = np.load(clap_path)
|
150 |
+
clap = torch.Tensor(clap).unsqueeze(0).to(self.device)
|
151 |
+
input = {
|
152 |
+
"input_ids": input_ids,
|
153 |
+
"clap": clap,
|
154 |
+
"attention_mask": attention_mask,
|
155 |
+
"eos_mask": eos_mask,
|
156 |
+
}
|
157 |
+
|
158 |
+
generated_ids = self.model.generate(**input, **generation_config)
|
159 |
+
text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
160 |
+
|
161 |
+
return text
|
metric/__init__.py
ADDED
File without changes
|
metric/compute_metric.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from aac_metrics import evaluate
|
3 |
+
import copy
|
4 |
+
metric_list = ["bleu_1", "bleu_4", "rouge_l", "meteor", "spider_fl"]
|
5 |
+
|
6 |
+
if __name__=='__main__':
|
7 |
+
csv_path = "/workspace/audiobart/csv/predictions/prediction_clap.csv"
|
8 |
+
df = pd.read_csv(csv_path)
|
9 |
+
|
10 |
+
predictions = []
|
11 |
+
references = []
|
12 |
+
for idx in range(len(df)):
|
13 |
+
predictions.append(df.loc[idx]['prediction'])
|
14 |
+
reference = [df.loc[idx]['caption_1'],df.loc[idx]['caption_2'],df.loc[idx]['caption_3'],df.loc[idx]['caption_4'],df.loc[idx]['caption_5'] ]
|
15 |
+
references.append(reference)
|
16 |
+
|
17 |
+
print("> Evaluating predictions...")
|
18 |
+
result = evaluate(predictions, references, metrics=metric_list)
|
19 |
+
result = {k: v.item() for k, v in result[0].items()}
|
20 |
+
keys = list(result.keys())
|
21 |
+
for key in keys:
|
22 |
+
if "fluerr" in key:
|
23 |
+
del result[key]
|
24 |
+
print(result)
|
metric/compute_metric_from_scratch.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('..')
|
3 |
+
sys.path.append('.')
|
4 |
+
|
5 |
+
from aac_metrics import evaluate
|
6 |
+
from inference import AudioBartInference
|
7 |
+
from tqdm import tqdm
|
8 |
+
import os
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
12 |
+
metric_list = ["bleu_1", "bleu_4", "rouge_l", "meteor", "spider_fl"]
|
13 |
+
|
14 |
+
if __name__ == "__main__":
|
15 |
+
dataset = "AudioCaps"
|
16 |
+
# dataset = "clotho"
|
17 |
+
ckpt_path = "/data/jyk/aac_results/bart_base/audiocaps_35e5_2000/checkpoints/epoch_8"
|
18 |
+
|
19 |
+
# ckpt_path = "/data/jyk/aac_results/masking/linear_scalinEg/checkpoints/epoch_14"
|
20 |
+
max_encodec_length = 1022
|
21 |
+
infer_module = AudioBartInference(ckpt_path, max_encodec_length)
|
22 |
+
from_encodec = True
|
23 |
+
csv_path = f"/workspace/audiobart/csv/{dataset}/test.csv"
|
24 |
+
base_path = f"/data/jyk/aac_dataset/{dataset}/encodec_16"
|
25 |
+
clap_name = "clap_audio_fused"
|
26 |
+
df = pd.read_csv(csv_path)
|
27 |
+
|
28 |
+
generation_config = {
|
29 |
+
"_from_model_config": True,
|
30 |
+
"bos_token_id": 0,
|
31 |
+
"decoder_start_token_id": 2,
|
32 |
+
"early_stopping": True,
|
33 |
+
"eos_token_id": 2,
|
34 |
+
"forced_bos_token_id": 0,
|
35 |
+
"forced_eos_token_id": 2,
|
36 |
+
"no_repeat_ngram_size": 3,
|
37 |
+
"num_beams": 4,
|
38 |
+
"pad_token_id": 1,
|
39 |
+
"max_length": 50
|
40 |
+
}
|
41 |
+
|
42 |
+
print(f"> Making Predictions for model {ckpt_path}...")
|
43 |
+
predictions = []
|
44 |
+
references = []
|
45 |
+
for idx in tqdm(range(len(df)), dynamic_ncols=True, colour="BLUE"):
|
46 |
+
if not from_encodec:
|
47 |
+
wav_path = df.loc[idx]['file_name']
|
48 |
+
else:
|
49 |
+
wav_path = df.loc[idx]['file_path']
|
50 |
+
wav_path = os.path.join(base_path,wav_path)
|
51 |
+
if not os.path.exists(wav_path):
|
52 |
+
pass
|
53 |
+
|
54 |
+
if not from_encodec:
|
55 |
+
prediction = infer_module.infer(wav_path)
|
56 |
+
else:
|
57 |
+
prediction = infer_module.infer_from_encodec(wav_path, clap_name, generation_config)
|
58 |
+
|
59 |
+
predictions.append(prediction[0])
|
60 |
+
reference = [df.loc[idx]['caption_1'],df.loc[idx]['caption_2'],df.loc[idx]['caption_3'],df.loc[idx]['caption_4'],df.loc[idx]['caption_5'] ]
|
61 |
+
references.append(reference)
|
62 |
+
|
63 |
+
print("> Evaluating predictions...")
|
64 |
+
result = evaluate(predictions, references, metrics=metric_list)
|
65 |
+
result = {k: round(v.item(),4) for k, v in result[0].items()}
|
66 |
+
keys = list(result.keys())
|
67 |
+
for key in keys:
|
68 |
+
if "fluerr" in key:
|
69 |
+
del result[key]
|
70 |
+
print(result)
|
metric/make_predictions.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append('..')
|
3 |
+
|
4 |
+
from inference import AudioBartInference
|
5 |
+
from tqdm import tqdm
|
6 |
+
import os
|
7 |
+
import pandas as pd
|
8 |
+
import csv
|
9 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
ckpt_path = "/data/jyk/aac_results/clap/clap/checkpoints/epoch_12"
|
14 |
+
infer_module = AudioBartInference(ckpt_path)
|
15 |
+
from_encodec = True
|
16 |
+
csv_path = "/workspace/audiobart/csv/test.csv"
|
17 |
+
base_path = "/data/jyk/aac_dataset/clotho/encodec"
|
18 |
+
df = pd.read_csv(csv_path)
|
19 |
+
save_path = "/workspace/audiobart/csv/predictions/prediction_clap.csv"
|
20 |
+
f = open(save_path, 'w', newline='')
|
21 |
+
writer = csv.writer(f)
|
22 |
+
writer.writerow(['file_path', 'prediction', 'caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5'])
|
23 |
+
|
24 |
+
print(f"> Making Predictions for model {ckpt_path}...")
|
25 |
+
for idx in tqdm(range(len(df)), dynamic_ncols=True, colour="red"):
|
26 |
+
if not from_encodec:
|
27 |
+
wav_path = df.loc[idx]['file_name']
|
28 |
+
else:
|
29 |
+
wav_path = df.loc[idx]['file_path']
|
30 |
+
wav_path = os.path.join(base_path,wav_path)
|
31 |
+
if not os.path.exists(wav_path):
|
32 |
+
pass
|
33 |
+
|
34 |
+
if not from_encodec:
|
35 |
+
prediction = infer_module.infer(wav_path)
|
36 |
+
else:
|
37 |
+
prediction = infer_module.infer_from_encodec(wav_path)
|
38 |
+
line = [wav_path, prediction[0], df.loc[idx]['caption_1'], df.loc[idx]['caption_2'],df.loc[idx]['caption_3'],df.loc[idx]['caption_4'],df.loc[idx]['caption_5']]
|
39 |
+
writer.writerow(line)
|
40 |
+
|
41 |
+
f.close()
|
modeling/__init__.py
ADDED
File without changes
|
modeling/enclap_bart.py
ADDED
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn import CrossEntropyLoss
|
8 |
+
from transformers.modeling_outputs import (
|
9 |
+
BaseModelOutput,
|
10 |
+
Seq2SeqLMOutput,
|
11 |
+
Seq2SeqModelOutput,
|
12 |
+
)
|
13 |
+
from transformers.models.bart.configuration_bart import BartConfig
|
14 |
+
from transformers.models.bart.modeling_bart import (
|
15 |
+
BartDecoder,
|
16 |
+
BartEncoderLayer,
|
17 |
+
BartForConditionalGeneration,
|
18 |
+
BartLearnedPositionalEmbedding,
|
19 |
+
BartModel,
|
20 |
+
BartPretrainedModel,
|
21 |
+
_expand_mask,
|
22 |
+
shift_tokens_right,
|
23 |
+
)
|
24 |
+
from transformers.utils import logging
|
25 |
+
|
26 |
+
from .modeling_outputs import EnClapBartOutput
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
class EnClapBartConfig(BartConfig):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
d_clap: int = 512,
|
35 |
+
num_rvq: int = 16,
|
36 |
+
encodec_vocab_size: int = 1024,
|
37 |
+
encodec_pad_token_id: int = 1024,
|
38 |
+
mcm_loss_scale: float = 0.7,
|
39 |
+
label_smoothing: float = 0.2,
|
40 |
+
**kwargs,
|
41 |
+
):
|
42 |
+
super().__init__(**kwargs)
|
43 |
+
self.d_clap = d_clap
|
44 |
+
self.num_rvq = num_rvq
|
45 |
+
self.encodec_vocab_size = encodec_vocab_size
|
46 |
+
self.encodec_pad_token_id = encodec_pad_token_id
|
47 |
+
self.mcm_loss_scale = mcm_loss_scale
|
48 |
+
self.label_smoothing = label_smoothing
|
49 |
+
|
50 |
+
|
51 |
+
class EnClapBartEncoder(BartPretrainedModel):
|
52 |
+
"""
|
53 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
54 |
+
[`BartEncoderLayer`].
|
55 |
+
|
56 |
+
Args:
|
57 |
+
config: BartConfig
|
58 |
+
embed_tokens (nn.Embedding): output embedding
|
59 |
+
"""
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self, config: EnClapBartConfig, embed_tokens: Optional[nn.Embedding] = None
|
63 |
+
):
|
64 |
+
super().__init__(config)
|
65 |
+
|
66 |
+
self.dropout = config.dropout
|
67 |
+
self.layerdrop = config.encoder_layerdrop
|
68 |
+
|
69 |
+
clap_dim = config.d_clap
|
70 |
+
embed_dim = config.d_model
|
71 |
+
self.padding_idx = config.pad_token_id
|
72 |
+
self.max_source_positions = config.max_position_embeddings
|
73 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
74 |
+
|
75 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
76 |
+
|
77 |
+
if embed_tokens is not None:
|
78 |
+
self.embed_tokens.weight = embed_tokens.weight
|
79 |
+
|
80 |
+
self.embed_encodec = nn.ModuleList(
|
81 |
+
[
|
82 |
+
nn.Embedding(
|
83 |
+
math.ceil((config.encodec_vocab_size + 1) / 64) * 64,
|
84 |
+
config.d_model,
|
85 |
+
padding_idx=config.encodec_pad_token_id,
|
86 |
+
)
|
87 |
+
for _ in range(config.num_rvq)
|
88 |
+
]
|
89 |
+
)
|
90 |
+
|
91 |
+
self.clap_projection = nn.Linear(clap_dim, embed_dim)
|
92 |
+
|
93 |
+
self.embed_positions = BartLearnedPositionalEmbedding(
|
94 |
+
config.max_position_embeddings,
|
95 |
+
embed_dim,
|
96 |
+
)
|
97 |
+
self.layers = nn.ModuleList(
|
98 |
+
[BartEncoderLayer(config) for _ in range(config.encoder_layers)]
|
99 |
+
)
|
100 |
+
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
101 |
+
|
102 |
+
self.gradient_checkpointing = False
|
103 |
+
# Initialize weights and apply final processing
|
104 |
+
self.post_init()
|
105 |
+
|
106 |
+
def get_input_embeddings(self):
|
107 |
+
return self.embed_tokens
|
108 |
+
|
109 |
+
def set_input_embeddings(self, value):
|
110 |
+
self.embed_tokens = value
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
input_ids: torch.LongTensor = None,
|
115 |
+
clap_embedding: Optional[torch.Tensor] = None,
|
116 |
+
encodec_mask: Optional[torch.Tensor] = None,
|
117 |
+
attention_mask: Optional[torch.Tensor] = None,
|
118 |
+
head_mask: Optional[torch.Tensor] = None,
|
119 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
120 |
+
output_attentions: Optional[bool] = None,
|
121 |
+
output_hidden_states: Optional[bool] = None,
|
122 |
+
return_dict: Optional[bool] = None,
|
123 |
+
) -> Union[Tuple, BaseModelOutput]:
|
124 |
+
r"""
|
125 |
+
Args:
|
126 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
127 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
128 |
+
provide it.
|
129 |
+
|
130 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
131 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
132 |
+
|
133 |
+
[What are input IDs?](../glossary#input-ids)
|
134 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
135 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
136 |
+
|
137 |
+
- 1 for tokens that are **not masked**,
|
138 |
+
- 0 for tokens that are **masked**.
|
139 |
+
|
140 |
+
[What are attention masks?](../glossary#attention-mask)
|
141 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
142 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
143 |
+
|
144 |
+
- 1 indicates the head is **not masked**,
|
145 |
+
- 0 indicates the head is **masked**.
|
146 |
+
|
147 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
148 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
149 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
150 |
+
than the model's internal embedding lookup matrix.
|
151 |
+
output_attentions (`bool`, *optional*):
|
152 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
153 |
+
returned tensors for more detail.
|
154 |
+
output_hidden_states (`bool`, *optional*):
|
155 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
156 |
+
for more detail.
|
157 |
+
return_dict (`bool`, *optional*):
|
158 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
159 |
+
"""
|
160 |
+
output_attentions = (
|
161 |
+
output_attentions
|
162 |
+
if output_attentions is not None
|
163 |
+
else self.config.output_attentions
|
164 |
+
)
|
165 |
+
output_hidden_states = (
|
166 |
+
output_hidden_states
|
167 |
+
if output_hidden_states is not None
|
168 |
+
else self.config.output_hidden_states
|
169 |
+
)
|
170 |
+
return_dict = (
|
171 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
172 |
+
)
|
173 |
+
|
174 |
+
# retrieve input_ids and inputs_embeds
|
175 |
+
if input_ids is not None and inputs_embeds is not None:
|
176 |
+
raise ValueError(
|
177 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
178 |
+
)
|
179 |
+
elif input_ids is not None:
|
180 |
+
if input_ids.ndim == 2: # This is effectively just input = input_ids
|
181 |
+
input = input_ids
|
182 |
+
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
183 |
+
elif inputs_embeds is not None:
|
184 |
+
input = inputs_embeds[:, :, -1]
|
185 |
+
else:
|
186 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
187 |
+
|
188 |
+
if inputs_embeds is None:
|
189 |
+
if input_ids.ndim == 2:
|
190 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
191 |
+
elif input_ids.ndim == 3:
|
192 |
+
encodec_ids = torch.where(encodec_mask.unsqueeze(-1) > 0, input_ids, 0)
|
193 |
+
encodec_embeds = torch.zeros(
|
194 |
+
input_ids.shape[0], input_ids.shape[1], self.config.d_model
|
195 |
+
).to(self.device)
|
196 |
+
for i, embed in enumerate(self.embed_encodec):
|
197 |
+
encodec_embeds = encodec_embeds + embed(encodec_ids[..., i])
|
198 |
+
bart_ids = torch.where(encodec_mask == 0, input_ids[..., 0], 0)
|
199 |
+
bart_embeds = self.embed_tokens(bart_ids)
|
200 |
+
input_embeds = torch.where(
|
201 |
+
encodec_mask.unsqueeze(-1) > 0, encodec_embeds, bart_embeds
|
202 |
+
)
|
203 |
+
|
204 |
+
# Get CLAP embedding
|
205 |
+
if clap_embedding is not None:
|
206 |
+
clap_embedding = self.clap_projection(clap_embedding)
|
207 |
+
input_embeds[:, 0] = clap_embedding
|
208 |
+
inputs_embeds = input_embeds.to(self.device)
|
209 |
+
|
210 |
+
batch_size = input_ids.size(0)
|
211 |
+
embed_pos = self.embed_positions(input_ids).to(self.device)
|
212 |
+
embed_pos = torch.cat(
|
213 |
+
[
|
214 |
+
torch.zeros(batch_size, 1, self.config.d_model).to(self.device),
|
215 |
+
embed_pos[:, :-1],
|
216 |
+
],
|
217 |
+
dim=1,
|
218 |
+
)
|
219 |
+
|
220 |
+
hidden_states = inputs_embeds + embed_pos
|
221 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
222 |
+
hidden_states = nn.functional.dropout(
|
223 |
+
hidden_states, p=self.dropout, training=self.training
|
224 |
+
)
|
225 |
+
|
226 |
+
# expand attention_mask
|
227 |
+
if attention_mask is not None:
|
228 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
229 |
+
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
|
230 |
+
|
231 |
+
encoder_states = () if output_hidden_states else None
|
232 |
+
all_attentions = () if output_attentions else None
|
233 |
+
|
234 |
+
# check if head_mask has a correct number of layers specified if desired
|
235 |
+
if head_mask is not None:
|
236 |
+
if head_mask.size()[0] != (len(self.layers)):
|
237 |
+
raise ValueError(
|
238 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
239 |
+
f" {head_mask.size()[0]}."
|
240 |
+
)
|
241 |
+
|
242 |
+
for idx, encoder_layer in enumerate(self.layers):
|
243 |
+
if output_hidden_states:
|
244 |
+
encoder_states = encoder_states + (hidden_states,)
|
245 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
246 |
+
dropout_probability = random.uniform(0, 1)
|
247 |
+
if self.training and (
|
248 |
+
dropout_probability < self.layerdrop
|
249 |
+
): # skip the layer
|
250 |
+
layer_outputs = (None, None)
|
251 |
+
else:
|
252 |
+
if self.gradient_checkpointing and self.training:
|
253 |
+
|
254 |
+
def create_custom_forward(module):
|
255 |
+
def custom_forward(*inputs):
|
256 |
+
return module(*inputs, output_attentions)
|
257 |
+
|
258 |
+
return custom_forward
|
259 |
+
|
260 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
261 |
+
create_custom_forward(encoder_layer),
|
262 |
+
hidden_states,
|
263 |
+
attention_mask,
|
264 |
+
(head_mask[idx] if head_mask is not None else None),
|
265 |
+
)
|
266 |
+
else:
|
267 |
+
layer_outputs = encoder_layer(
|
268 |
+
hidden_states,
|
269 |
+
attention_mask,
|
270 |
+
layer_head_mask=(
|
271 |
+
head_mask[idx] if head_mask is not None else None
|
272 |
+
),
|
273 |
+
output_attentions=output_attentions,
|
274 |
+
)
|
275 |
+
|
276 |
+
hidden_states = layer_outputs[0]
|
277 |
+
|
278 |
+
if output_attentions:
|
279 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
280 |
+
|
281 |
+
if output_hidden_states:
|
282 |
+
encoder_states = encoder_states + (hidden_states,)
|
283 |
+
|
284 |
+
if not return_dict:
|
285 |
+
return tuple(
|
286 |
+
v
|
287 |
+
for v in [hidden_states, encoder_states, all_attentions]
|
288 |
+
if v is not None
|
289 |
+
)
|
290 |
+
return BaseModelOutput(
|
291 |
+
last_hidden_state=hidden_states,
|
292 |
+
hidden_states=encoder_states,
|
293 |
+
attentions=all_attentions,
|
294 |
+
)
|
295 |
+
|
296 |
+
|
297 |
+
class EnClapBartModel(BartModel):
|
298 |
+
def __init__(self, config: EnClapBartConfig):
|
299 |
+
super(BartModel, self).__init__(config)
|
300 |
+
|
301 |
+
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
302 |
+
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
303 |
+
|
304 |
+
self.encoder = EnClapBartEncoder(config, self.shared)
|
305 |
+
self.decoder = BartDecoder(config, self.shared)
|
306 |
+
|
307 |
+
# Initialize weights and apply final processing
|
308 |
+
self.post_init()
|
309 |
+
|
310 |
+
def forward(
|
311 |
+
self,
|
312 |
+
input_ids: torch.LongTensor = None,
|
313 |
+
clap_embedding: Optional[torch.Tensor] = None,
|
314 |
+
encodec_mask: Optional[torch.Tensor] = None,
|
315 |
+
attention_mask: Optional[torch.Tensor] = None,
|
316 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
317 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
318 |
+
head_mask: Optional[torch.Tensor] = None,
|
319 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
320 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
321 |
+
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
322 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
323 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
324 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
325 |
+
use_cache: Optional[bool] = None,
|
326 |
+
output_attentions: Optional[bool] = None,
|
327 |
+
output_hidden_states: Optional[bool] = None,
|
328 |
+
return_dict: Optional[bool] = None,
|
329 |
+
) -> Union[Tuple, Seq2SeqModelOutput]:
|
330 |
+
# different to other models, Bart automatically creates decoder_input_ids from
|
331 |
+
# input_ids if no decoder_input_ids are provided
|
332 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
333 |
+
if input_ids is None:
|
334 |
+
raise ValueError(
|
335 |
+
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
|
336 |
+
"passed, `input_ids` cannot be `None`. Please pass either "
|
337 |
+
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
|
338 |
+
)
|
339 |
+
|
340 |
+
decoder_input_ids = shift_tokens_right(
|
341 |
+
input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
|
342 |
+
)
|
343 |
+
|
344 |
+
output_attentions = (
|
345 |
+
output_attentions
|
346 |
+
if output_attentions is not None
|
347 |
+
else self.config.output_attentions
|
348 |
+
)
|
349 |
+
output_hidden_states = (
|
350 |
+
output_hidden_states
|
351 |
+
if output_hidden_states is not None
|
352 |
+
else self.config.output_hidden_states
|
353 |
+
)
|
354 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
355 |
+
return_dict = (
|
356 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
357 |
+
)
|
358 |
+
|
359 |
+
if encoder_outputs is None:
|
360 |
+
encoder_outputs = self.encoder(
|
361 |
+
input_ids=input_ids,
|
362 |
+
clap_embedding=clap_embedding,
|
363 |
+
encodec_mask=encodec_mask,
|
364 |
+
attention_mask=attention_mask,
|
365 |
+
head_mask=head_mask,
|
366 |
+
inputs_embeds=inputs_embeds,
|
367 |
+
output_attentions=output_attentions,
|
368 |
+
output_hidden_states=output_hidden_states,
|
369 |
+
return_dict=return_dict,
|
370 |
+
)
|
371 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
372 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
373 |
+
encoder_outputs = BaseModelOutput(
|
374 |
+
last_hidden_state=encoder_outputs[0],
|
375 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
376 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
377 |
+
)
|
378 |
+
|
379 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
380 |
+
decoder_outputs = self.decoder(
|
381 |
+
input_ids=decoder_input_ids,
|
382 |
+
attention_mask=decoder_attention_mask,
|
383 |
+
encoder_hidden_states=encoder_outputs[0],
|
384 |
+
encoder_attention_mask=attention_mask,
|
385 |
+
head_mask=decoder_head_mask,
|
386 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
387 |
+
past_key_values=past_key_values,
|
388 |
+
inputs_embeds=decoder_inputs_embeds,
|
389 |
+
use_cache=use_cache,
|
390 |
+
output_attentions=output_attentions,
|
391 |
+
output_hidden_states=output_hidden_states,
|
392 |
+
return_dict=return_dict,
|
393 |
+
)
|
394 |
+
|
395 |
+
if not return_dict:
|
396 |
+
return decoder_outputs + encoder_outputs
|
397 |
+
|
398 |
+
return Seq2SeqModelOutput(
|
399 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
400 |
+
past_key_values=decoder_outputs.past_key_values,
|
401 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
402 |
+
decoder_attentions=decoder_outputs.attentions,
|
403 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
404 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
405 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
406 |
+
encoder_attentions=encoder_outputs.attentions,
|
407 |
+
)
|
408 |
+
|
409 |
+
|
410 |
+
class EnClapBartForConditionalGeneration(BartForConditionalGeneration):
|
411 |
+
config_class = EnClapBartConfig
|
412 |
+
|
413 |
+
def __init__(self, config: EnClapBartConfig):
|
414 |
+
super(BartForConditionalGeneration, self).__init__(config)
|
415 |
+
self.model = EnClapBartModel(config)
|
416 |
+
self.register_buffer(
|
417 |
+
"final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
|
418 |
+
)
|
419 |
+
self.lm_head = nn.Linear(
|
420 |
+
config.d_model, self.model.shared.num_embeddings, bias=False
|
421 |
+
)
|
422 |
+
self.mcm_heads = nn.ModuleList(
|
423 |
+
[
|
424 |
+
nn.Linear(config.d_model, config.encodec_vocab_size)
|
425 |
+
for _ in range(config.num_rvq)
|
426 |
+
]
|
427 |
+
)
|
428 |
+
|
429 |
+
# Initialize weights and apply final processing
|
430 |
+
self.post_init()
|
431 |
+
|
432 |
+
def forward(
|
433 |
+
self,
|
434 |
+
input_ids: torch.LongTensor = None,
|
435 |
+
clap_embedding: Optional[torch.Tensor] = None,
|
436 |
+
encodec_mask: Optional[torch.Tensor] = None,
|
437 |
+
attention_mask: Optional[torch.Tensor] = None,
|
438 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
439 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
440 |
+
head_mask: Optional[torch.Tensor] = None,
|
441 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
442 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
443 |
+
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
444 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
445 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
446 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
447 |
+
labels: Optional[torch.LongTensor] = None,
|
448 |
+
mcm_labels: Optional[List[torch.LongTensor]] = None,
|
449 |
+
use_cache: Optional[bool] = None,
|
450 |
+
output_attentions: Optional[bool] = None,
|
451 |
+
output_hidden_states: Optional[bool] = None,
|
452 |
+
return_dict: Optional[bool] = None,
|
453 |
+
) -> Union[Tuple, Seq2SeqLMOutput]:
|
454 |
+
r"""
|
455 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
456 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
457 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
458 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
459 |
+
|
460 |
+
Returns:
|
461 |
+
"""
|
462 |
+
return_dict = (
|
463 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
464 |
+
)
|
465 |
+
|
466 |
+
if labels is not None:
|
467 |
+
if use_cache:
|
468 |
+
logger.warning(
|
469 |
+
"The `use_cache` argument is changed to `False` since `labels` is provided."
|
470 |
+
)
|
471 |
+
use_cache = False
|
472 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
473 |
+
decoder_input_ids = shift_tokens_right(
|
474 |
+
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
475 |
+
)
|
476 |
+
|
477 |
+
outputs = self.model(
|
478 |
+
input_ids,
|
479 |
+
clap_embedding=clap_embedding,
|
480 |
+
encodec_mask=encodec_mask,
|
481 |
+
attention_mask=attention_mask,
|
482 |
+
decoder_input_ids=decoder_input_ids,
|
483 |
+
encoder_outputs=encoder_outputs,
|
484 |
+
decoder_attention_mask=decoder_attention_mask,
|
485 |
+
head_mask=head_mask,
|
486 |
+
decoder_head_mask=decoder_head_mask,
|
487 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
488 |
+
past_key_values=past_key_values,
|
489 |
+
inputs_embeds=inputs_embeds,
|
490 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
491 |
+
use_cache=use_cache,
|
492 |
+
output_attentions=output_attentions,
|
493 |
+
output_hidden_states=output_hidden_states,
|
494 |
+
return_dict=return_dict,
|
495 |
+
)
|
496 |
+
|
497 |
+
mcm_loss = None
|
498 |
+
if mcm_labels is not None:
|
499 |
+
mcm_loss = 0.0
|
500 |
+
loss_fct = CrossEntropyLoss()
|
501 |
+
for i, mcm_head in enumerate(self.mcm_heads):
|
502 |
+
mcm_logits = mcm_head(outputs.encoder_last_hidden_state)
|
503 |
+
loss_scale = 1 / 2 ** (i + 1)
|
504 |
+
loss = loss_fct(
|
505 |
+
mcm_logits.view(-1, self.config.encodec_vocab_size),
|
506 |
+
mcm_labels[..., i].reshape(-1),
|
507 |
+
)
|
508 |
+
mcm_loss = mcm_loss + loss * loss_scale
|
509 |
+
|
510 |
+
lm_logits = self.lm_head(outputs[0])
|
511 |
+
lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
|
512 |
+
|
513 |
+
masked_lm_loss = None
|
514 |
+
if labels is not None:
|
515 |
+
labels = labels.to(lm_logits.device)
|
516 |
+
loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
|
517 |
+
masked_lm_loss = loss_fct(
|
518 |
+
lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
|
519 |
+
)
|
520 |
+
|
521 |
+
loss = None
|
522 |
+
if mcm_loss is None:
|
523 |
+
loss = masked_lm_loss
|
524 |
+
elif masked_lm_loss is None:
|
525 |
+
loss = mcm_loss
|
526 |
+
else:
|
527 |
+
mcm_loss = mcm_loss * self.config.mcm_loss_scale
|
528 |
+
loss = masked_lm_loss + mcm_loss
|
529 |
+
|
530 |
+
if not return_dict:
|
531 |
+
output = (lm_logits,) + outputs[1:]
|
532 |
+
return (
|
533 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
534 |
+
)
|
535 |
+
|
536 |
+
return EnClapBartOutput(
|
537 |
+
loss=loss,
|
538 |
+
lm_loss=masked_lm_loss,
|
539 |
+
mcm_loss=mcm_loss,
|
540 |
+
logits=lm_logits,
|
541 |
+
past_key_values=outputs.past_key_values,
|
542 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
543 |
+
decoder_attentions=outputs.decoder_attentions,
|
544 |
+
cross_attentions=outputs.cross_attentions,
|
545 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
546 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
547 |
+
encoder_attentions=outputs.encoder_attentions,
|
548 |
+
)
|
modeling/modeling_outputs.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class EnClapBartOutput(Seq2SeqLMOutput):
|
10 |
+
mcm_loss: Optional[torch.FloatTensor] = None
|
11 |
+
lm_loss: Optional[torch.FloatTensor] = None
|
port_weights.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from math import ceil
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--ckpt_path", "-c", type=str)
|
10 |
+
args = parser.parse_args()
|
11 |
+
|
12 |
+
weight_name_map = {
|
13 |
+
"model.encodec_embeddings": None,
|
14 |
+
"encodec_embeddings": "embed_encodec",
|
15 |
+
"encodec_mlm_head": "mcm_heads",
|
16 |
+
}
|
17 |
+
|
18 |
+
ckpt_path = Path(args.ckpt_path)
|
19 |
+
weight_file = ckpt_path / "pytorch_model.bin"
|
20 |
+
state_dict = torch.load(weight_file, map_location="cpu")
|
21 |
+
new_state_dict = {}
|
22 |
+
for key in state_dict:
|
23 |
+
new_key = key
|
24 |
+
for orig, repl in weight_name_map.items():
|
25 |
+
if repl is None:
|
26 |
+
if orig in new_key:
|
27 |
+
new_key = None
|
28 |
+
break
|
29 |
+
continue
|
30 |
+
new_key = new_key.replace(orig, repl)
|
31 |
+
if new_key:
|
32 |
+
new_state_dict[new_key] = state_dict[key]
|
33 |
+
for key in new_state_dict:
|
34 |
+
if "model.encoder.embed_encodec" in key:
|
35 |
+
dim = new_state_dict[key].shape[0]
|
36 |
+
new_weight = torch.normal(
|
37 |
+
0, 1, (ceil(dim / 64) * 64, new_state_dict[key].shape[1])
|
38 |
+
)
|
39 |
+
new_weight[:dim] = new_state_dict[key]
|
40 |
+
new_state_dict[key] = new_weight
|
41 |
+
weight_file.rename(weight_file.with_suffix(".bin.bak"))
|
42 |
+
torch.save(new_state_dict, weight_file)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aac-metrics==0.4.2
|
2 |
+
accelerate==0.20.3
|
3 |
+
datasets==2.13.1
|
4 |
+
encodec==0.1.1
|
5 |
+
laion-clap==1.1.4
|
6 |
+
librosa==0.10.1
|
7 |
+
markupsafe==2.0.1
|
8 |
+
omegaconf==2.3.0
|
9 |
+
soundfile==0.12.1
|
10 |
+
tensorboard==2.13.0
|
11 |
+
tokenizers==0.13.3
|
12 |
+
torch==1.13.0
|
13 |
+
torchaudio==0.13.0
|
14 |
+
transformers==4.29.0
|
test/bart_test.ipynb
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/opt/conda/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"from transformers import AutoTokenizer\n",
|
19 |
+
"from transformers.models.bart.modeling_bart import BartForConditionalGeneration"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 2,
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large\")"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 35,
|
34 |
+
"metadata": {},
|
35 |
+
"outputs": [
|
36 |
+
{
|
37 |
+
"data": {
|
38 |
+
"text/plain": [
|
39 |
+
"('bart/tokenizer/tokenizer_config.json',\n",
|
40 |
+
" 'bart/tokenizer/special_tokens_map.json',\n",
|
41 |
+
" 'bart/tokenizer/vocab.json',\n",
|
42 |
+
" 'bart/tokenizer/merges.txt',\n",
|
43 |
+
" 'bart/tokenizer/added_tokens.json',\n",
|
44 |
+
" 'bart/tokenizer/tokenizer.json')"
|
45 |
+
]
|
46 |
+
},
|
47 |
+
"execution_count": 35,
|
48 |
+
"metadata": {},
|
49 |
+
"output_type": "execute_result"
|
50 |
+
}
|
51 |
+
],
|
52 |
+
"source": [
|
53 |
+
"tokenizer.save_pretrained(\"bart/tokenizer\")"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 18,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large\", forced_bos_token_id=0)"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": 4,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [
|
70 |
+
{
|
71 |
+
"name": "stderr",
|
72 |
+
"output_type": "stream",
|
73 |
+
"text": [
|
74 |
+
"Some weights of AudioBartForConditionalGeneration were not initialized from the model checkpoint at bart/model/ and are newly initialized: ['model.encodec_embeddings.3.weight', 'model.encodec_embeddings.4.weight', 'model.encodec_embeddings.1.weight', 'model.encodec_embeddings.0.weight', 'model.encoder.encodec_embeddings.7.weight', 'model.encodec_embeddings.2.weight', 'model.encodec_embeddings.6.weight', 'model.encoder.encodec_embeddings.0.weight', 'model.encodec_embeddings.7.weight', 'model.encoder.encodec_embeddings.4.weight', 'model.encoder.encodec_embeddings.2.weight', 'model.encoder.encodec_embeddings.3.weight', 'model.encodec_embeddings.5.weight', 'model.encoder.encodec_embeddings.5.weight', 'model.encoder.encodec_embeddings.1.weight', 'model.encoder.encodec_embeddings.6.weight']\n",
|
75 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
76 |
+
]
|
77 |
+
}
|
78 |
+
],
|
79 |
+
"source": [
|
80 |
+
"from modeling.audiobart import AudioBartForConditionalGeneration\n",
|
81 |
+
"model = AudioBartForConditionalGeneration.from_pretrained(\"bart/model/\")"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "code",
|
86 |
+
"execution_count": 5,
|
87 |
+
"metadata": {},
|
88 |
+
"outputs": [
|
89 |
+
{
|
90 |
+
"name": "stdout",
|
91 |
+
"output_type": "stream",
|
92 |
+
"text": [
|
93 |
+
"{'input_ids': tensor([[ 0, 31414, 127, 50264, 32440, 3807, 118, 32440, 3807, 118,\n",
|
94 |
+
" 25610, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n"
|
95 |
+
]
|
96 |
+
}
|
97 |
+
],
|
98 |
+
"source": [
|
99 |
+
"text = \"Hello my <mask> yeppi yeppi yo\"\n",
|
100 |
+
"input = tokenizer(text, return_tensors='pt')\n",
|
101 |
+
"print(input)"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": 8,
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"generated_ids = model.generate(input[\"input_ids\"])"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"execution_count": 33,
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"ids = output.logits.detach().numpy().argmax(-1)"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"cell_type": "code",
|
124 |
+
"execution_count": 11,
|
125 |
+
"metadata": {},
|
126 |
+
"outputs": [
|
127 |
+
{
|
128 |
+
"data": {
|
129 |
+
"text/plain": [
|
130 |
+
"['Hello my friends, yeppi yeppiiyeppiyeppii ye']"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
"execution_count": 11,
|
134 |
+
"metadata": {},
|
135 |
+
"output_type": "execute_result"
|
136 |
+
}
|
137 |
+
],
|
138 |
+
"source": [
|
139 |
+
"tokenizer.batch_decode(generated_ids, skip_special_tokens=True)"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 36,
|
145 |
+
"metadata": {},
|
146 |
+
"outputs": [],
|
147 |
+
"source": [
|
148 |
+
"model.save_pretrained(\"bart/model\")"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": 51,
|
154 |
+
"metadata": {},
|
155 |
+
"outputs": [
|
156 |
+
{
|
157 |
+
"name": "stdout",
|
158 |
+
"output_type": "stream",
|
159 |
+
"text": [
|
160 |
+
"{'input_ids': tensor([[ 0, 7842, 330, 506, 1536, 267, 131, 6634, 36807, 571,\n",
|
161 |
+
" 20920, 127, 766, 16, 32440, 3807, 118, 32440, 3807, 118,\n",
|
162 |
+
" 25610, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n"
|
163 |
+
]
|
164 |
+
}
|
165 |
+
],
|
166 |
+
"source": [
|
167 |
+
"text = \"adskfalsj;lsdfg Hello my name is yeppi yeppi yo\"\n",
|
168 |
+
"input = tokenizer(text, return_tensors='pt')\n",
|
169 |
+
"print(input)"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": 52,
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"output = model.forward(**input)"
|
179 |
+
]
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"cell_type": "code",
|
183 |
+
"execution_count": 54,
|
184 |
+
"metadata": {},
|
185 |
+
"outputs": [
|
186 |
+
{
|
187 |
+
"name": "stderr",
|
188 |
+
"output_type": "stream",
|
189 |
+
"text": [
|
190 |
+
"/opt/conda/lib/python3.9/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
|
191 |
+
" warnings.warn(\n"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"data": {
|
196 |
+
"text/plain": [
|
197 |
+
"['</s><s>adskfalsj;lsdfg Hello my name is yeppi ye</s>']"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
"execution_count": 54,
|
201 |
+
"metadata": {},
|
202 |
+
"output_type": "execute_result"
|
203 |
+
}
|
204 |
+
],
|
205 |
+
"source": [
|
206 |
+
"tokenizer.batch_decode(model.generate(input['input_ids']))"
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": 39,
|
212 |
+
"metadata": {},
|
213 |
+
"outputs": [
|
214 |
+
{
|
215 |
+
"data": {
|
216 |
+
"text/plain": [
|
217 |
+
"['<s>Hello my name is yeppi yeppi yo</s>']"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
"execution_count": 39,
|
221 |
+
"metadata": {},
|
222 |
+
"output_type": "execute_result"
|
223 |
+
}
|
224 |
+
],
|
225 |
+
"source": [
|
226 |
+
"tokenizer.batch_decode(input['input_ids'])"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": 45,
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"from transformers.models.bart.modeling_bart import shift_tokens_right"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"execution_count": 48,
|
241 |
+
"metadata": {},
|
242 |
+
"outputs": [
|
243 |
+
{
|
244 |
+
"name": "stdout",
|
245 |
+
"output_type": "stream",
|
246 |
+
"text": [
|
247 |
+
"tensor([[ 0, 0, 31414, 127, 766, 16, 32440, 3807, 118, 32440,\n",
|
248 |
+
" 3807, 118, 25610]])\n"
|
249 |
+
]
|
250 |
+
}
|
251 |
+
],
|
252 |
+
"source": [
|
253 |
+
"print(shift_tokens_right(input['input_ids'], pad_token_id=1, decoder_start_token_id=0))"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"execution_count": 55,
|
259 |
+
"metadata": {},
|
260 |
+
"outputs": [
|
261 |
+
{
|
262 |
+
"name": "stderr",
|
263 |
+
"output_type": "stream",
|
264 |
+
"text": [
|
265 |
+
"Downloading (…)lve/main/config.json: 100%|██████████| 1.58k/1.58k [00:00<00:00, 589kB/s]\n",
|
266 |
+
"Downloading (…)olve/main/vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 1.29MB/s]\n",
|
267 |
+
"Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 884kB/s]\n",
|
268 |
+
"Downloading (…)/main/tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 7.43MB/s]\n"
|
269 |
+
]
|
270 |
+
}
|
271 |
+
],
|
272 |
+
"source": [
|
273 |
+
"cnn_tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large-cnn\")"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"cell_type": "code",
|
278 |
+
"execution_count": 63,
|
279 |
+
"metadata": {},
|
280 |
+
"outputs": [],
|
281 |
+
"source": [
|
282 |
+
"original_text = \"ArithmeticErrorThe tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.\""
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "code",
|
287 |
+
"execution_count": 68,
|
288 |
+
"metadata": {},
|
289 |
+
"outputs": [
|
290 |
+
{
|
291 |
+
"name": "stdout",
|
292 |
+
"output_type": "stream",
|
293 |
+
"text": [
|
294 |
+
"['<s>ArithmeticErrorThe tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.</s>']\n"
|
295 |
+
]
|
296 |
+
}
|
297 |
+
],
|
298 |
+
"source": [
|
299 |
+
"input = cnn_tokenizer(text=original_text, return_tensors='pt')\n",
|
300 |
+
"print(cnn_tokenizer.batch_decode(input['input_ids']))"
|
301 |
+
]
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"execution_count": 65,
|
306 |
+
"metadata": {},
|
307 |
+
"outputs": [],
|
308 |
+
"source": [
|
309 |
+
"cnn_model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")"
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "code",
|
314 |
+
"execution_count": 66,
|
315 |
+
"metadata": {},
|
316 |
+
"outputs": [
|
317 |
+
{
|
318 |
+
"name": "stderr",
|
319 |
+
"output_type": "stream",
|
320 |
+
"text": [
|
321 |
+
"/opt/conda/lib/python3.9/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (142) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
|
322 |
+
" warnings.warn(\n"
|
323 |
+
]
|
324 |
+
},
|
325 |
+
{
|
326 |
+
"data": {
|
327 |
+
"text/plain": [
|
328 |
+
"['</s><s>The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world. It was the first structure to reach a height of 300 metres.</s>']"
|
329 |
+
]
|
330 |
+
},
|
331 |
+
"execution_count": 66,
|
332 |
+
"metadata": {},
|
333 |
+
"output_type": "execute_result"
|
334 |
+
}
|
335 |
+
],
|
336 |
+
"source": [
|
337 |
+
"cnn_tokenizer.batch_decode(cnn_model.generate(input['input_ids']))"
|
338 |
+
]
|
339 |
+
}
|
340 |
+
],
|
341 |
+
"metadata": {
|
342 |
+
"kernelspec": {
|
343 |
+
"display_name": "base",
|
344 |
+
"language": "python",
|
345 |
+
"name": "python3"
|
346 |
+
},
|
347 |
+
"language_info": {
|
348 |
+
"codemirror_mode": {
|
349 |
+
"name": "ipython",
|
350 |
+
"version": 3
|
351 |
+
},
|
352 |
+
"file_extension": ".py",
|
353 |
+
"mimetype": "text/x-python",
|
354 |
+
"name": "python",
|
355 |
+
"nbconvert_exporter": "python",
|
356 |
+
"pygments_lexer": "ipython3",
|
357 |
+
"version": "3.9.12"
|
358 |
+
},
|
359 |
+
"orig_nbformat": 4
|
360 |
+
},
|
361 |
+
"nbformat": 4,
|
362 |
+
"nbformat_minor": 2
|
363 |
+
}
|
test/clap_test.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
test/dataset_test.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
test/dataset_test.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append(".")
|
3 |
+
sys.path.append("..")
|
4 |
+
|
5 |
+
from datasets import load_dataset
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
from modeling.audiobart import AudioBartForConditionalGeneration
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
from data.collator import EncodecCollator
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import os
|
14 |
+
|
15 |
+
if __name__=="__main__":
|
16 |
+
model = AudioBartForConditionalGeneration.from_pretrained('bart/model')
|
17 |
+
base_path = "/data/jyk/aac_dataset/AudioCaps/encodec_16/"
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
|
19 |
+
data_files = {"train": "csv/AudioCaps/train.csv"}
|
20 |
+
max_encodec_length = 1021
|
21 |
+
clap_base_path = "/data/jyk/aac_dataset/AudioCaps/clap"
|
22 |
+
|
23 |
+
raw_dataset = load_dataset("csv", data_files=data_files)
|
24 |
+
|
25 |
+
def preprocess_function(example):
|
26 |
+
path = example['file_path']
|
27 |
+
encodec = np.load(os.path.join(base_path, path))
|
28 |
+
if encodec.shape[0]>max_encodec_length:
|
29 |
+
encodec = encodec[:max_encodec_length, :]
|
30 |
+
clap = np.load(os.path.join(clap_base_path, path))
|
31 |
+
attention_mask = np.ones(encodec.shape[0]+3).astype(np.int64)
|
32 |
+
target_text = tokenizer(text_target=example['caption'])
|
33 |
+
|
34 |
+
return {'input_ids': encodec, 'clap': clap, 'attention_mask': attention_mask, 'labels': target_text['input_ids'], 'decoder_attention_mask': target_text['attention_mask']}
|
35 |
+
|
36 |
+
train_dataset = raw_dataset['train'].map(preprocess_function)
|
37 |
+
train_dataset.set_format("pt", columns=['input_ids', 'attention_mask', 'clap', 'labels', 'decoder_attention_mask'])
|
38 |
+
|
39 |
+
train_data_collator = EncodecCollator(
|
40 |
+
tokenizer=tokenizer,
|
41 |
+
model=model,
|
42 |
+
return_tensors="pt",
|
43 |
+
random_sampling=False,
|
44 |
+
max_length=max_encodec_length,
|
45 |
+
num_subsampling=0,
|
46 |
+
clap_masking_prob=-1,
|
47 |
+
encodec_masking_prob=0.15,
|
48 |
+
encodec_masking_length=10
|
49 |
+
)
|
50 |
+
|
51 |
+
train_dataloader = DataLoader(
|
52 |
+
train_dataset, shuffle=True, collate_fn=train_data_collator, batch_size=16)
|
53 |
+
|
54 |
+
for idx, batch in enumerate(train_dataloader):
|
55 |
+
# output = model.generate(**batch, max_length=100)
|
56 |
+
output = model(**batch)
|
57 |
+
print(output)
|
test/encodec_test.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
test/encodec_test.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from encodec import EncodecModel
|
2 |
+
from encodec.utils import convert_audio
|
3 |
+
|
4 |
+
import torchaudio
|
5 |
+
import torch
|
6 |
+
|
7 |
+
# Instantiate a pretrained EnCodec model
|
8 |
+
model = EncodecModel.encodec_model_24khz()
|
9 |
+
# The number of codebooks used will be determined bythe bandwidth selected.
|
10 |
+
# E.g. for a bandwidth of 6kbps, `n_q = 8` codebooks are used.
|
11 |
+
# Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8) and 12 kbps (n_q =16) and 24kbps (n_q=32).
|
12 |
+
# For the 48 kHz model, only 3, 6, 12, and 24 kbps are supported. The number
|
13 |
+
# of codebooks for each is half that of the 24 kHz model as the frame rate is twice as much.
|
14 |
+
model.set_target_bandwidth(6.0)
|
15 |
+
|
16 |
+
# Load and pre-process the audio waveform
|
17 |
+
wav, sr = torchaudio.load("<PATH_TO_AUDIO_FILE>")
|
18 |
+
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
|
19 |
+
wav = wav.unsqueeze(0)
|
20 |
+
|
21 |
+
# Extract discrete codes from EnCodec
|
22 |
+
with torch.no_grad():
|
23 |
+
encoded_frames = model.encode(wav)
|
24 |
+
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
|
test/eval_dataset_test.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
from modeling.audiobart import AudioBartForConditionalGeneration
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from data.collator import EncodecCollator
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import os
|
10 |
+
|
11 |
+
if __name__=="__main__":
|
12 |
+
model = AudioBartForConditionalGeneration.from_pretrained('bart/model')
|
13 |
+
basepath = "/data/jyk/aac_dataset/clotho/encodec/"
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
|
15 |
+
data_files = {"validation": "csv/valid_allcaps.csv"}
|
16 |
+
num_captions = 5
|
17 |
+
|
18 |
+
raw_dataset = load_dataset("csv", data_files=data_files)
|
19 |
+
|
20 |
+
def preprocess_eval(example):
|
21 |
+
path = example['file_path']
|
22 |
+
encodec = np.load(os.path.join(basepath, path))
|
23 |
+
if encodec.shape[0]>1022:
|
24 |
+
encodec = encodec[:1022, :]
|
25 |
+
attention_mask = np.ones(encodec.shape[0]+2).astype(np.int64)
|
26 |
+
captions = []
|
27 |
+
for i in range(1, num_captions+1):
|
28 |
+
captions.append(example['caption_'+str(i)])
|
29 |
+
|
30 |
+
return {'input_ids': encodec, 'attention_mask': attention_mask, 'captions': captions}
|
31 |
+
|
32 |
+
train_dataset = raw_dataset['validation'].map(preprocess_eval)
|
33 |
+
train_dataset.set_format('pt', columns=['input_ids', 'attention_mask'], output_all_columns=True)
|
34 |
+
# train_dataset.remove_columns('file_path', 'caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5')
|
35 |
+
data_collator = EncodecCollator(tokenizer=tokenizer, model=model, return_tensors="pt")
|
36 |
+
|
37 |
+
train_dataloader = DataLoader(
|
38 |
+
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=16)
|
39 |
+
|
40 |
+
for idx, batch in enumerate(train_dataloader):
|
41 |
+
output = model.generate(**batch, max_length=100)
|
42 |
+
print(output)
|
test/masking_test.ipynb
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stderr",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/opt/conda/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"import torch"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 2,
|
24 |
+
"metadata": {},
|
25 |
+
"outputs": [],
|
26 |
+
"source": [
|
27 |
+
"probability_matrix = torch.full((8, 15), 0.15)"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": 4,
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [
|
35 |
+
{
|
36 |
+
"data": {
|
37 |
+
"text/plain": [
|
38 |
+
"torch.Size([8, 15])"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
"execution_count": 4,
|
42 |
+
"metadata": {},
|
43 |
+
"output_type": "execute_result"
|
44 |
+
}
|
45 |
+
],
|
46 |
+
"source": [
|
47 |
+
"probability_matrix.shape"
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"cell_type": "code",
|
52 |
+
"execution_count": 9,
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [
|
55 |
+
{
|
56 |
+
"data": {
|
57 |
+
"text/plain": [
|
58 |
+
"tensor([[0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
59 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
60 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
61 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
62 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
63 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
64 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
65 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
66 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
67 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
68 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
69 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
70 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
71 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
|
72 |
+
" [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
|
73 |
+
" 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500]])"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
"execution_count": 9,
|
77 |
+
"metadata": {},
|
78 |
+
"output_type": "execute_result"
|
79 |
+
}
|
80 |
+
],
|
81 |
+
"source": [
|
82 |
+
"probability_matrix.masked_fill_(torch.tensor(0, dtype=torch.bool), value=0.0)"
|
83 |
+
]
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"cell_type": "code",
|
87 |
+
"execution_count": null,
|
88 |
+
"metadata": {},
|
89 |
+
"outputs": [],
|
90 |
+
"source": [
|
91 |
+
"masked_indices = torch.bernoulli()"
|
92 |
+
]
|
93 |
+
}
|
94 |
+
],
|
95 |
+
"metadata": {
|
96 |
+
"kernelspec": {
|
97 |
+
"display_name": "base",
|
98 |
+
"language": "python",
|
99 |
+
"name": "python3"
|
100 |
+
},
|
101 |
+
"language_info": {
|
102 |
+
"codemirror_mode": {
|
103 |
+
"name": "ipython",
|
104 |
+
"version": 3
|
105 |
+
},
|
106 |
+
"file_extension": ".py",
|
107 |
+
"mimetype": "text/x-python",
|
108 |
+
"name": "python",
|
109 |
+
"nbconvert_exporter": "python",
|
110 |
+
"pygments_lexer": "ipython3",
|
111 |
+
"version": "3.9.12"
|
112 |
+
},
|
113 |
+
"orig_nbformat": 4
|
114 |
+
},
|
115 |
+
"nbformat": 4,
|
116 |
+
"nbformat_minor": 2
|
117 |
+
}
|
test/metric_test.ipynb
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 8,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import evaluate"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 9,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [
|
17 |
+
{
|
18 |
+
"name": "stderr",
|
19 |
+
"output_type": "stream",
|
20 |
+
"text": [
|
21 |
+
"[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
|
22 |
+
"[nltk_data] Package wordnet is already up-to-date!\n",
|
23 |
+
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
|
24 |
+
"[nltk_data] Package punkt is already up-to-date!\n",
|
25 |
+
"[nltk_data] Downloading package omw-1.4 to /root/nltk_data...\n",
|
26 |
+
"[nltk_data] Package omw-1.4 is already up-to-date!\n"
|
27 |
+
]
|
28 |
+
}
|
29 |
+
],
|
30 |
+
"source": [
|
31 |
+
"metric = evaluate.load(\"meteor\")"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 6,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"import pandas as pd\n",
|
41 |
+
"df = pd.read_csv(\"csv/predictions.csv\")"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": 7,
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"predictions = []\n",
|
51 |
+
"references = []\n",
|
52 |
+
"for idx in range(len(df)):\n",
|
53 |
+
" predictions.append(df.loc[idx]['prediction'])\n",
|
54 |
+
" reference = [df.loc[idx]['caption1'],df.loc[idx]['caption2'],df.loc[idx]['caption3'],df.loc[idx]['caption4'],df.loc[idx]['caption5'] ]\n",
|
55 |
+
" references.append(reference)"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 8,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"from aac_metrics import evaluate\n",
|
65 |
+
"corpus_score = evaluate(predictions, references)"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
{
|
69 |
+
"cell_type": "code",
|
70 |
+
"execution_count": 4,
|
71 |
+
"metadata": {},
|
72 |
+
"outputs": [
|
73 |
+
{
|
74 |
+
"data": {
|
75 |
+
"text/plain": [
|
76 |
+
"{'bleu_1': tensor(0.3913, dtype=torch.float64),\n",
|
77 |
+
" 'bleu_2': tensor(0.1931, dtype=torch.float64),\n",
|
78 |
+
" 'bleu_3': tensor(0.1065, dtype=torch.float64),\n",
|
79 |
+
" 'bleu_4': tensor(0.0569, dtype=torch.float64),\n",
|
80 |
+
" 'meteor': tensor(0.1197, dtype=torch.float64),\n",
|
81 |
+
" 'rouge_l': tensor(0.2745, dtype=torch.float64),\n",
|
82 |
+
" 'cider_d': tensor(0.1235, dtype=torch.float64),\n",
|
83 |
+
" 'spice': tensor(0.0670, dtype=torch.float64),\n",
|
84 |
+
" 'spider': tensor(0.0953, dtype=torch.float64)}"
|
85 |
+
]
|
86 |
+
},
|
87 |
+
"execution_count": 4,
|
88 |
+
"metadata": {},
|
89 |
+
"output_type": "execute_result"
|
90 |
+
}
|
91 |
+
],
|
92 |
+
"source": [
|
93 |
+
"corpus_score[0]"
|
94 |
+
]
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"cell_type": "code",
|
98 |
+
"execution_count": 8,
|
99 |
+
"metadata": {},
|
100 |
+
"outputs": [
|
101 |
+
{
|
102 |
+
"name": "stdout",
|
103 |
+
"output_type": "stream",
|
104 |
+
"text": [
|
105 |
+
"{'bleu_1': 0.3912776883574468, 'bleu_2': 0.19312066269135236, 'bleu_3': 0.10651188216812753, 'bleu_4': 0.05690269475018141, 'meteor': 0.11968742992878356, 'rouge_l': 0.2744644068893943, 'cider_d': 0.12347016800968286, 'spice': 0.06704068138550699, 'spider': 0.09525542469759493}\n"
|
106 |
+
]
|
107 |
+
}
|
108 |
+
],
|
109 |
+
"source": [
|
110 |
+
"print({k: v.item() for k, v in corpus_score[0].items()})"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "code",
|
115 |
+
"execution_count": 13,
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"results = metric.compute(predictions=predictions, references=references)"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"cell_type": "code",
|
124 |
+
"execution_count": 14,
|
125 |
+
"metadata": {},
|
126 |
+
"outputs": [
|
127 |
+
{
|
128 |
+
"name": "stdout",
|
129 |
+
"output_type": "stream",
|
130 |
+
"text": [
|
131 |
+
"{'meteor': 0.26686702985116983}\n"
|
132 |
+
]
|
133 |
+
}
|
134 |
+
],
|
135 |
+
"source": [
|
136 |
+
"print(results)"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": 8,
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [],
|
144 |
+
"source": [
|
145 |
+
"bleu = evaluate.load(\"bleu\")"
|
146 |
+
]
|
147 |
+
},
|
148 |
+
{
|
149 |
+
"cell_type": "code",
|
150 |
+
"execution_count": 9,
|
151 |
+
"metadata": {},
|
152 |
+
"outputs": [],
|
153 |
+
"source": [
|
154 |
+
"from transformers import AutoTokenizer\n",
|
155 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large\")"
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": 11,
|
161 |
+
"metadata": {},
|
162 |
+
"outputs": [],
|
163 |
+
"source": [
|
164 |
+
"bleu_result = bleu.compute(predictions=predictions, references=references, max_order=4)"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": 12,
|
170 |
+
"metadata": {},
|
171 |
+
"outputs": [
|
172 |
+
{
|
173 |
+
"name": "stdout",
|
174 |
+
"output_type": "stream",
|
175 |
+
"text": [
|
176 |
+
"{'bleu': 0.06128958043343902, 'precisions': [0.42544588056899413, 0.09036238675413934, 0.031210136916404455, 0.01176031360836289], 'brevity_penalty': 1.0, 'length_ratio': 1.3508583690987124, 'translation_length': 13849, 'reference_length': 10252}\n"
|
177 |
+
]
|
178 |
+
}
|
179 |
+
],
|
180 |
+
"source": [
|
181 |
+
"print(bleu_result)"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "code",
|
186 |
+
"execution_count": 5,
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [
|
189 |
+
{
|
190 |
+
"ename": "AttributeError",
|
191 |
+
"evalue": "'function' object has no attribute 'load'",
|
192 |
+
"output_type": "error",
|
193 |
+
"traceback": [
|
194 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
195 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
196 |
+
"\u001b[1;32m/workspace/audiobart/metric_test.ipynb Cell 13\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f617564696f62617274222c2273657474696e6773223a7b22686f7374223a227373683a2f2f3138332e3131302e36322e3639227d7d/workspace/audiobart/metric_test.ipynb#X14sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m rouge_metric \u001b[39m=\u001b[39m evaluate\u001b[39m.\u001b[39;49mload(\u001b[39m\"\u001b[39m\u001b[39mrouge\u001b[39m\u001b[39m\"\u001b[39m)\n",
|
197 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'function' object has no attribute 'load'"
|
198 |
+
]
|
199 |
+
}
|
200 |
+
],
|
201 |
+
"source": [
|
202 |
+
"rouge_metric = evaluate.load(\"rouge\")"
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": 14,
|
208 |
+
"metadata": {},
|
209 |
+
"outputs": [],
|
210 |
+
"source": [
|
211 |
+
"rouge_result = rouge_metric.compute(predictions=predictions, references=references)"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 15,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [
|
219 |
+
{
|
220 |
+
"data": {
|
221 |
+
"text/plain": [
|
222 |
+
"{'rouge1': 0.30500784605917763,\n",
|
223 |
+
" 'rouge2': 0.08778194034686765,\n",
|
224 |
+
" 'rougeL': 0.2707178803695874,\n",
|
225 |
+
" 'rougeLsum': 0.27045227295118685}"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
"execution_count": 15,
|
229 |
+
"metadata": {},
|
230 |
+
"output_type": "execute_result"
|
231 |
+
}
|
232 |
+
],
|
233 |
+
"source": [
|
234 |
+
"rouge_result"
|
235 |
+
]
|
236 |
+
}
|
237 |
+
],
|
238 |
+
"metadata": {
|
239 |
+
"kernelspec": {
|
240 |
+
"display_name": "base",
|
241 |
+
"language": "python",
|
242 |
+
"name": "python3"
|
243 |
+
},
|
244 |
+
"language_info": {
|
245 |
+
"codemirror_mode": {
|
246 |
+
"name": "ipython",
|
247 |
+
"version": 3
|
248 |
+
},
|
249 |
+
"file_extension": ".py",
|
250 |
+
"mimetype": "text/x-python",
|
251 |
+
"name": "python",
|
252 |
+
"nbconvert_exporter": "python",
|
253 |
+
"pygments_lexer": "ipython3",
|
254 |
+
"version": "3.9.12"
|
255 |
+
},
|
256 |
+
"orig_nbformat": 4
|
257 |
+
},
|
258 |
+
"nbformat": 4,
|
259 |
+
"nbformat_minor": 2
|
260 |
+
}
|
test/subsample_test.ipynb
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import sys\n",
|
10 |
+
"sys.path.append(\"..\")"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 3,
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"name": "stderr",
|
20 |
+
"output_type": "stream",
|
21 |
+
"text": [
|
22 |
+
"/opt/conda/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
23 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
24 |
+
]
|
25 |
+
}
|
26 |
+
],
|
27 |
+
"source": [
|
28 |
+
"from modeling.audiobart import Subsampler"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "code",
|
33 |
+
"execution_count": 4,
|
34 |
+
"metadata": {},
|
35 |
+
"outputs": [],
|
36 |
+
"source": [
|
37 |
+
"subsampler = Subsampler(1024, 3, 2)"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "code",
|
42 |
+
"execution_count": 5,
|
43 |
+
"metadata": {},
|
44 |
+
"outputs": [],
|
45 |
+
"source": [
|
46 |
+
"from utils import count_parameters"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 6,
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [
|
54 |
+
{
|
55 |
+
"data": {
|
56 |
+
"text/plain": [
|
57 |
+
"1053696"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
"execution_count": 6,
|
61 |
+
"metadata": {},
|
62 |
+
"output_type": "execute_result"
|
63 |
+
}
|
64 |
+
],
|
65 |
+
"source": [
|
66 |
+
"count_parameters(subsampler)"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": 7,
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"import torch"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 25,
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [
|
83 |
+
{
|
84 |
+
"name": "stdout",
|
85 |
+
"output_type": "stream",
|
86 |
+
"text": [
|
87 |
+
"torch.Size([8, 1023, 1024])\n"
|
88 |
+
]
|
89 |
+
}
|
90 |
+
],
|
91 |
+
"source": [
|
92 |
+
"input = torch.randn(8, 4095, 1024)\n",
|
93 |
+
"output = subsampler(input)\n",
|
94 |
+
"output = subsampler(output)\n",
|
95 |
+
"print(output.shape)"
|
96 |
+
]
|
97 |
+
}
|
98 |
+
],
|
99 |
+
"metadata": {
|
100 |
+
"kernelspec": {
|
101 |
+
"display_name": "base",
|
102 |
+
"language": "python",
|
103 |
+
"name": "python3"
|
104 |
+
},
|
105 |
+
"language_info": {
|
106 |
+
"codemirror_mode": {
|
107 |
+
"name": "ipython",
|
108 |
+
"version": 3
|
109 |
+
},
|
110 |
+
"file_extension": ".py",
|
111 |
+
"mimetype": "text/x-python",
|
112 |
+
"name": "python",
|
113 |
+
"nbconvert_exporter": "python",
|
114 |
+
"pygments_lexer": "ipython3",
|
115 |
+
"version": "3.9.12"
|
116 |
+
},
|
117 |
+
"orig_nbformat": 4
|
118 |
+
},
|
119 |
+
"nbformat": 4,
|
120 |
+
"nbformat_minor": 2
|
121 |
+
}
|
test/train_test.ipynb
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from datasets import load_dataset\n",
|
10 |
+
"from transformers import AutoTokenizer"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 3,
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [],
|
18 |
+
"source": [
|
19 |
+
"import os\n",
|
20 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"4,5,6,7\""
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": 4,
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"basepath = \"/data/jyk/aac_dataset/clotho/encodec/\"\n",
|
30 |
+
"tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 5,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"data_files = {\"train\": \"csv/train_short.csv\", \"validation\": \"csv/valid_short.csv\"}"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": 6,
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [
|
47 |
+
{
|
48 |
+
"name": "stderr",
|
49 |
+
"output_type": "stream",
|
50 |
+
"text": [
|
51 |
+
"Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-8533483370f473b7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)\n",
|
52 |
+
"100%|██████████| 2/2 [00:00<00:00, 923.96it/s]\n"
|
53 |
+
]
|
54 |
+
}
|
55 |
+
],
|
56 |
+
"source": [
|
57 |
+
"raw_dataset = load_dataset(\"csv\", data_files=data_files)"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": 7,
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"data": {
|
67 |
+
"text/plain": [
|
68 |
+
"Dataset({\n",
|
69 |
+
" features: ['file_path', 'caption'],\n",
|
70 |
+
" num_rows: 19175\n",
|
71 |
+
"})"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
"execution_count": 7,
|
75 |
+
"metadata": {},
|
76 |
+
"output_type": "execute_result"
|
77 |
+
}
|
78 |
+
],
|
79 |
+
"source": [
|
80 |
+
"raw_dataset['train']"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 10,
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [],
|
88 |
+
"source": [
|
89 |
+
"from data.collator import EncodecCollator\n",
|
90 |
+
"import numpy as np\n",
|
91 |
+
"import os"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": 11,
|
97 |
+
"metadata": {},
|
98 |
+
"outputs": [],
|
99 |
+
"source": [
|
100 |
+
"def preprocessing(example):\n",
|
101 |
+
" path = example['file_path']\n",
|
102 |
+
" encodec = np.load(os.path.join(basepath, path))\n",
|
103 |
+
" if encodec.shape[0]>1022:\n",
|
104 |
+
" encodec = encodec[:1022, :]\n",
|
105 |
+
" attention_mask = np.ones(encodec.shape[0]+2)\n",
|
106 |
+
" target_text = tokenizer(text_target=example['caption'])\n",
|
107 |
+
"\n",
|
108 |
+
" return {'input_ids': encodec , 'attention_mask': attention_mask, 'labels': target_text['input_ids'], 'decoder_attention_mask': target_text['attention_mask']}\n"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": 12,
|
114 |
+
"metadata": {},
|
115 |
+
"outputs": [
|
116 |
+
{
|
117 |
+
"name": "stderr",
|
118 |
+
"output_type": "stream",
|
119 |
+
"text": [
|
120 |
+
" \r"
|
121 |
+
]
|
122 |
+
}
|
123 |
+
],
|
124 |
+
"source": [
|
125 |
+
"train_dataset = raw_dataset['train'].map(preprocessing, num_proc=16)"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": 13,
|
131 |
+
"metadata": {},
|
132 |
+
"outputs": [],
|
133 |
+
"source": [
|
134 |
+
"train_dataset.set_format(\"np\", columns=['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"cell_type": "code",
|
139 |
+
"execution_count": 32,
|
140 |
+
"metadata": {},
|
141 |
+
"outputs": [
|
142 |
+
{
|
143 |
+
"name": "stderr",
|
144 |
+
"output_type": "stream",
|
145 |
+
"text": [
|
146 |
+
"Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-8533483370f473b7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-a3db71731640afd3_*_of_00016.arrow\n"
|
147 |
+
]
|
148 |
+
}
|
149 |
+
],
|
150 |
+
"source": [
|
151 |
+
"valid_dataset = raw_dataset['validation'].map(preprocessing, num_proc=16)\n",
|
152 |
+
"valid_dataset.set_format(\"np\", columns=['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": 16,
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [
|
160 |
+
{
|
161 |
+
"name": "stderr",
|
162 |
+
"output_type": "stream",
|
163 |
+
"text": [
|
164 |
+
"Some weights of AudioBartForConditionalGeneration were not initialized from the model checkpoint at bart/model and are newly initialized: ['model.encodec_embeddings.6.weight', 'model.encodec_embeddings.4.weight', 'model.encodec_embeddings.1.weight', 'model.encodec_embeddings.7.weight', 'model.encodec_embeddings.5.weight', 'model.encodec_embeddings.3.weight', 'model.encodec_embeddings.2.weight', 'model.encodec_embeddings.0.weight']\n",
|
165 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
166 |
+
]
|
167 |
+
}
|
168 |
+
],
|
169 |
+
"source": [
|
170 |
+
"from modeling.audiobart import AudioBartForConditionalGeneration\n",
|
171 |
+
"model = AudioBartForConditionalGeneration.from_pretrained('bart/model')"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "code",
|
176 |
+
"execution_count": 25,
|
177 |
+
"metadata": {},
|
178 |
+
"outputs": [
|
179 |
+
{
|
180 |
+
"name": "stdout",
|
181 |
+
"output_type": "stream",
|
182 |
+
"text": [
|
183 |
+
"414688256\n"
|
184 |
+
]
|
185 |
+
}
|
186 |
+
],
|
187 |
+
"source": [
|
188 |
+
"from utils import count_parameters\n",
|
189 |
+
"print(count_parameters(model))"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": 17,
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"data_collator = EncodecCollator(tokenizer=tokenizer, model=model, return_tensors=\"pt\")"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": 19,
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [],
|
206 |
+
"source": [
|
207 |
+
"from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 36,
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"training_args = Seq2SeqTrainingArguments('summary_test', per_gpu_train_batch_size=16)"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": 37,
|
222 |
+
"metadata": {},
|
223 |
+
"outputs": [
|
224 |
+
{
|
225 |
+
"name": "stderr",
|
226 |
+
"output_type": "stream",
|
227 |
+
"text": [
|
228 |
+
"Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.\n"
|
229 |
+
]
|
230 |
+
}
|
231 |
+
],
|
232 |
+
"source": [
|
233 |
+
"trainer = Seq2SeqTrainer(\n",
|
234 |
+
" model, training_args, train_dataset=valid_dataset, eval_dataset=valid_dataset, data_collator=data_collator, tokenizer=tokenizer\n",
|
235 |
+
")"
|
236 |
+
]
|
237 |
+
}
|
238 |
+
],
|
239 |
+
"metadata": {
|
240 |
+
"kernelspec": {
|
241 |
+
"display_name": "base",
|
242 |
+
"language": "python",
|
243 |
+
"name": "python3"
|
244 |
+
},
|
245 |
+
"language_info": {
|
246 |
+
"codemirror_mode": {
|
247 |
+
"name": "ipython",
|
248 |
+
"version": 3
|
249 |
+
},
|
250 |
+
"file_extension": ".py",
|
251 |
+
"mimetype": "text/x-python",
|
252 |
+
"name": "python",
|
253 |
+
"nbconvert_exporter": "python",
|
254 |
+
"pygments_lexer": "ipython3",
|
255 |
+
"version": "3.9.12"
|
256 |
+
},
|
257 |
+
"orig_nbformat": 4
|
258 |
+
},
|
259 |
+
"nbformat": 4,
|
260 |
+
"nbformat_minor": 2
|
261 |
+
}
|
test/train_test.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
from modeling.audiobart import AudioBartForConditionalGeneration
|
4 |
+
from data.collator import EncodecCollator
|
5 |
+
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import os
|
10 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
|
11 |
+
|
12 |
+
if __name__=="__main__":
|
13 |
+
model = AudioBartForConditionalGeneration.from_pretrained('bart/model')
|
14 |
+
basepath = "/data/jyk/aac_dataset/clotho/encodec/"
|
15 |
+
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
|
16 |
+
data_files = {"train": "csv/train_short.csv", "validation": "csv/valid_short.csv"}
|
17 |
+
|
18 |
+
raw_dataset = load_dataset("csv", data_files=data_files)
|
19 |
+
|
20 |
+
def preprocessing(example):
|
21 |
+
path = example['file_path']
|
22 |
+
encodec = np.load(os.path.join(basepath, path))
|
23 |
+
if encodec.shape[0]>1022:
|
24 |
+
encodec = encodec[:1022, :]
|
25 |
+
attention_mask = np.ones(encodec.shape[0]+2)
|
26 |
+
target_text = tokenizer(text_target=example['caption'])
|
27 |
+
|
28 |
+
return {'input_ids': encodec , 'attention_mask': attention_mask, 'labels': target_text['input_ids'], 'decoder_attention_mask': target_text['attention_mask']}
|
29 |
+
|
30 |
+
train_dataset = raw_dataset['validation'].map(preprocessing)
|
31 |
+
train_dataset.set_format("pt", columns=['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])
|
32 |
+
|
33 |
+
data_collator = EncodecCollator(tokenizer=tokenizer, model=model, return_tensors="pt")
|
34 |
+
|
35 |
+
training_args = Seq2SeqTrainingArguments('summary_test', per_gpu_train_batch_size=20)
|
36 |
+
|
37 |
+
trainer = Seq2SeqTrainer(
|
38 |
+
model, training_args, train_dataset=train_dataset, eval_dataset=train_dataset, data_collator=data_collator, tokenizer=tokenizer
|
39 |
+
)
|
40 |
+
|
41 |
+
trainer.train()
|
train.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
import logging
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
|
21 |
+
import datasets
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import transformers
|
25 |
+
from aac_metrics import evaluate
|
26 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
27 |
+
from accelerate.logging import get_logger
|
28 |
+
from accelerate.utils import set_seed
|
29 |
+
from datasets import load_dataset
|
30 |
+
from omegaconf import OmegaConf
|
31 |
+
from torch.utils.data import DataLoader
|
32 |
+
from tqdm.auto import tqdm
|
33 |
+
from transformers import (
|
34 |
+
AutoTokenizer,
|
35 |
+
BartConfig,
|
36 |
+
get_inverse_sqrt_schedule,
|
37 |
+
get_scheduler,
|
38 |
+
)
|
39 |
+
|
40 |
+
from data.collator import DataCollatorForEnClapBart
|
41 |
+
from data.preprocess import Preprocessor
|
42 |
+
from modeling.enclap_bart import EnClapBartForConditionalGeneration
|
43 |
+
|
44 |
+
logger = get_logger(__name__)
|
45 |
+
metric_list = ["meteor", "spider"]
|
46 |
+
|
47 |
+
|
48 |
+
def main():
|
49 |
+
# Load Configuration
|
50 |
+
cfg_path = sys.argv[1]
|
51 |
+
args = OmegaConf.load(cfg_path)
|
52 |
+
|
53 |
+
# Initialize Logging
|
54 |
+
accelerator_log_kwargs = {}
|
55 |
+
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
56 |
+
if args.with_tracking:
|
57 |
+
accelerator_log_kwargs["log_with"] = args.report_to
|
58 |
+
accelerator_log_kwargs["project_dir"] = args.output_dir
|
59 |
+
|
60 |
+
# Initialize Accelerator
|
61 |
+
accelerator = Accelerator(
|
62 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
63 |
+
split_batches=args.split_batches,
|
64 |
+
kwargs_handlers=[ddp_kwargs],
|
65 |
+
**accelerator_log_kwargs,
|
66 |
+
)
|
67 |
+
# Handle the repository creation
|
68 |
+
if accelerator.is_main_process:
|
69 |
+
if args.output_dir is not None:
|
70 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
71 |
+
with open(os.path.join(args.output_dir, "args.yaml"), "w") as f:
|
72 |
+
OmegaConf.save(args, f)
|
73 |
+
accelerator.wait_for_everyone()
|
74 |
+
|
75 |
+
# Make one log on every process with the configuration for debugging.
|
76 |
+
logging.basicConfig(
|
77 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
78 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
79 |
+
level=logging.INFO,
|
80 |
+
)
|
81 |
+
file_handler = logging.FileHandler(os.path.join(args.output_dir, "train_log.txt"))
|
82 |
+
logger.logger.addHandler(file_handler)
|
83 |
+
logger.info(accelerator.state, main_process_only=False)
|
84 |
+
if accelerator.is_local_main_process:
|
85 |
+
datasets.utils.logging.set_verbosity_warning()
|
86 |
+
transformers.utils.logging.set_verbosity_warning()
|
87 |
+
else:
|
88 |
+
datasets.utils.logging.set_verbosity_error()
|
89 |
+
transformers.utils.logging.set_verbosity_error()
|
90 |
+
|
91 |
+
# If passed along, set the training seed now.
|
92 |
+
if args.seed is not None:
|
93 |
+
set_seed(args.seed)
|
94 |
+
|
95 |
+
# Get the datasets
|
96 |
+
data_files = {}
|
97 |
+
data_files_eval = {}
|
98 |
+
if args.train_file is not None:
|
99 |
+
data_files["train"] = args.train_file
|
100 |
+
if args.validation_file is not None:
|
101 |
+
data_files_eval["validation"] = args.validation_file
|
102 |
+
|
103 |
+
extension = args.train_file.split(".")[-1]
|
104 |
+
raw_datasets = load_dataset(extension, data_files=data_files)
|
105 |
+
raw_datasets_eval = load_dataset(extension, data_files=data_files_eval)
|
106 |
+
|
107 |
+
# Load pretrained model and tokenizer
|
108 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
|
109 |
+
if args.config_name_or_path is not None:
|
110 |
+
config = BartConfig.from_pretrained(args.config_name_or_path)
|
111 |
+
else:
|
112 |
+
config = None
|
113 |
+
|
114 |
+
if args.model_name_or_path is not None:
|
115 |
+
if config is None:
|
116 |
+
model = EnClapBartForConditionalGeneration.from_pretrained(
|
117 |
+
args.model_name_or_path
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
model = EnClapBartForConditionalGeneration.from_pretrained(
|
121 |
+
args.model_name_or_path, config=config
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
model = EnClapBartForConditionalGeneration(config=config)
|
125 |
+
|
126 |
+
# Set the generation config
|
127 |
+
if args.val_max_target_length is None:
|
128 |
+
args.val_max_target_length = args.max_target_length
|
129 |
+
|
130 |
+
# Set max encodec length based on the shape of the positional encoding
|
131 |
+
max_encodec_length = model.config.max_position_embeddings - 2
|
132 |
+
label_pad_token_id = (
|
133 |
+
-100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
134 |
+
)
|
135 |
+
preprocessor = Preprocessor(
|
136 |
+
args.encodec_base_path,
|
137 |
+
args.clap_base_path,
|
138 |
+
tokenizer,
|
139 |
+
model.config.max_position_embeddings,
|
140 |
+
args.encodec_masking_prob,
|
141 |
+
args.encodec_masking_span,
|
142 |
+
label_pad_token_id,
|
143 |
+
model.config.encodec_vocab_size,
|
144 |
+
args.eval_num_captions,
|
145 |
+
)
|
146 |
+
|
147 |
+
with accelerator.main_process_first():
|
148 |
+
train_dataset = raw_datasets["train"].map(
|
149 |
+
preprocessor.preprocess_train,
|
150 |
+
num_proc=args.preprocessing_num_workers,
|
151 |
+
load_from_cache_file=not args.overwrite_cache,
|
152 |
+
desc="Running tokenizer on dataset",
|
153 |
+
)
|
154 |
+
train_dataset.set_format(
|
155 |
+
"pt",
|
156 |
+
columns=[
|
157 |
+
"input_ids",
|
158 |
+
"attention_mask",
|
159 |
+
"clap",
|
160 |
+
"labels",
|
161 |
+
"decoder_attention_mask",
|
162 |
+
],
|
163 |
+
)
|
164 |
+
|
165 |
+
# Temporarily set max_target_length for validation.
|
166 |
+
eval_dataset = raw_datasets_eval["validation"].map(
|
167 |
+
preprocessor.preprocess_eval,
|
168 |
+
num_proc=args.preprocessing_num_workers,
|
169 |
+
load_from_cache_file=not args.overwrite_cache,
|
170 |
+
desc="Running tokenizer on dataset",
|
171 |
+
)
|
172 |
+
eval_dataset.set_format(
|
173 |
+
"pt",
|
174 |
+
columns=["input_ids", "attention_mask", "clap"],
|
175 |
+
output_all_columns=True,
|
176 |
+
)
|
177 |
+
|
178 |
+
train_data_collator = DataCollatorForEnClapBart(
|
179 |
+
tokenizer=tokenizer,
|
180 |
+
model=model,
|
181 |
+
return_tensors="pt",
|
182 |
+
label_pad_token_id=label_pad_token_id,
|
183 |
+
max_length=max_encodec_length,
|
184 |
+
encodec_masking_prob=args.encodec_masking_prob,
|
185 |
+
encodec_masking_span=args.encodec_masking_span,
|
186 |
+
)
|
187 |
+
valid_data_collator = DataCollatorForEnClapBart(
|
188 |
+
tokenizer=tokenizer,
|
189 |
+
model=model,
|
190 |
+
return_tensors="pt",
|
191 |
+
label_pad_token_id=label_pad_token_id,
|
192 |
+
max_length=max_encodec_length,
|
193 |
+
)
|
194 |
+
|
195 |
+
train_dataloader = DataLoader(
|
196 |
+
train_dataset,
|
197 |
+
shuffle=True,
|
198 |
+
collate_fn=train_data_collator,
|
199 |
+
batch_size=args.per_device_train_batch_size,
|
200 |
+
)
|
201 |
+
eval_dataloader = DataLoader(
|
202 |
+
eval_dataset,
|
203 |
+
collate_fn=valid_data_collator,
|
204 |
+
batch_size=args.per_device_eval_batch_size,
|
205 |
+
)
|
206 |
+
|
207 |
+
# Optimizer
|
208 |
+
# Split weights in two groups, one with weight decay and the other not.
|
209 |
+
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
|
210 |
+
optimizer_grouped_parameters = [
|
211 |
+
{
|
212 |
+
"params": [
|
213 |
+
p
|
214 |
+
for n, p in model.named_parameters()
|
215 |
+
if not any(nd in n for nd in no_decay)
|
216 |
+
],
|
217 |
+
"weight_decay": args.weight_decay,
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"params": [
|
221 |
+
p
|
222 |
+
for n, p in model.named_parameters()
|
223 |
+
if any(nd in n for nd in no_decay)
|
224 |
+
],
|
225 |
+
"weight_decay": 0.0,
|
226 |
+
},
|
227 |
+
]
|
228 |
+
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
|
229 |
+
|
230 |
+
# Scheduler and math around the number of training steps.
|
231 |
+
overrode_max_train_steps = False
|
232 |
+
num_update_steps_per_epoch = math.ceil(
|
233 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
234 |
+
)
|
235 |
+
if args.max_train_steps is None:
|
236 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
237 |
+
overrode_max_train_steps = True
|
238 |
+
|
239 |
+
if args.lr_scheduler_type == "inverse_sqrt" and hasattr(args, "time_scale"):
|
240 |
+
lr_scheduler = get_inverse_sqrt_schedule(
|
241 |
+
optimizer=optimizer,
|
242 |
+
num_warmup_steps=args.num_warmup_steps,
|
243 |
+
timescale=args.time_scale,
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
lr_scheduler = get_scheduler(
|
247 |
+
name=args.lr_scheduler_type,
|
248 |
+
optimizer=optimizer,
|
249 |
+
num_warmup_steps=args.num_warmup_steps,
|
250 |
+
num_training_steps=args.max_train_steps,
|
251 |
+
)
|
252 |
+
|
253 |
+
# Prepare everything with our `accelerator`.
|
254 |
+
(
|
255 |
+
model,
|
256 |
+
optimizer,
|
257 |
+
train_dataloader,
|
258 |
+
eval_dataloader,
|
259 |
+
lr_scheduler,
|
260 |
+
) = accelerator.prepare(
|
261 |
+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
262 |
+
)
|
263 |
+
|
264 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
265 |
+
num_update_steps_per_epoch = math.ceil(
|
266 |
+
len(train_dataloader) / args.gradient_accumulation_steps
|
267 |
+
)
|
268 |
+
if overrode_max_train_steps:
|
269 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
270 |
+
# Afterwards we recalculate our number of training epochs
|
271 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
272 |
+
|
273 |
+
# Figure out how many steps we should save the Accelerator states
|
274 |
+
checkpointing_steps = args.checkpointing_steps
|
275 |
+
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
276 |
+
checkpointing_steps = int(checkpointing_steps)
|
277 |
+
|
278 |
+
# The trackers initializes automatically on the main process.
|
279 |
+
if args.with_tracking:
|
280 |
+
accelerator.init_trackers(args.logging_dir)
|
281 |
+
|
282 |
+
# Train!
|
283 |
+
total_batch_size = (
|
284 |
+
args.per_device_train_batch_size
|
285 |
+
* accelerator.num_processes
|
286 |
+
* args.gradient_accumulation_steps
|
287 |
+
)
|
288 |
+
|
289 |
+
if args.split_batches:
|
290 |
+
total_batch_size = int(total_batch_size / accelerator.num_processes)
|
291 |
+
|
292 |
+
logger.info("***** Running training *****")
|
293 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
294 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
295 |
+
logger.info(
|
296 |
+
f" Instantaneous batch size per device = {args.per_device_train_batch_size}"
|
297 |
+
)
|
298 |
+
logger.info(
|
299 |
+
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
300 |
+
)
|
301 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
302 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
303 |
+
|
304 |
+
completed_steps = 0
|
305 |
+
starting_epoch = 0
|
306 |
+
# Potentially load in the weights and states from a previous save
|
307 |
+
if not args.overwrite_output_dir and os.path.exists(
|
308 |
+
os.path.join(args.output_dir, "checkpoints")
|
309 |
+
):
|
310 |
+
if args.resume_from_checkpoint is not None:
|
311 |
+
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
312 |
+
accelerator.load_state(args.resume_from_checkpoint)
|
313 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
314 |
+
else:
|
315 |
+
# Get the most recent checkpoint
|
316 |
+
dirs = [
|
317 |
+
f
|
318 |
+
for f in os.scandir(os.path.join(args.output_dir, "checkpoints"))
|
319 |
+
if f.is_dir()
|
320 |
+
]
|
321 |
+
dirs.sort(key=os.path.getctime)
|
322 |
+
path = dirs[
|
323 |
+
-1
|
324 |
+
].name # Sorts folders by date modified, most recent checkpoint is the last
|
325 |
+
accelerator.print(f"Resumed from checkpoint: {dirs[-1]}")
|
326 |
+
accelerator.load_state(dirs[-1])
|
327 |
+
# Extract `epoch_{i}` or `step_{i}`
|
328 |
+
training_difference = os.path.splitext(path)[0]
|
329 |
+
|
330 |
+
if "epoch" in training_difference:
|
331 |
+
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
|
332 |
+
resume_step = None
|
333 |
+
completed_steps = starting_epoch * num_update_steps_per_epoch
|
334 |
+
else:
|
335 |
+
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
336 |
+
resume_step = (
|
337 |
+
int(training_difference.replace("step_", ""))
|
338 |
+
* args.gradient_accumulation_steps
|
339 |
+
)
|
340 |
+
starting_epoch = resume_step // len(train_dataloader)
|
341 |
+
resume_step -= starting_epoch * len(train_dataloader)
|
342 |
+
completed_steps = resume_step // args.gradient_accumulation_stepp
|
343 |
+
|
344 |
+
# update the progress_bar if load from checkpoint
|
345 |
+
if args.with_tracking:
|
346 |
+
total_loss = 0
|
347 |
+
logging_loss = 0
|
348 |
+
before_epoch_loss = 0
|
349 |
+
|
350 |
+
if args.encodec_masking_prob > 0:
|
351 |
+
total_encodec_loss = 0
|
352 |
+
logging_encodec_loss = 0
|
353 |
+
before_epoch_encodec_loss = 0
|
354 |
+
|
355 |
+
for epoch in range(starting_epoch, args.num_train_epochs):
|
356 |
+
model.train()
|
357 |
+
if (
|
358 |
+
args.resume_from_checkpoint
|
359 |
+
and epoch == starting_epoch
|
360 |
+
and resume_step is not None
|
361 |
+
):
|
362 |
+
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
|
363 |
+
active_dataloader = accelerator.skip_first_batches(
|
364 |
+
train_dataloader, resume_step
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
active_dataloader = train_dataloader
|
368 |
+
logger.info(f"***** Running epoch {epoch} *****")
|
369 |
+
epoch_iterator = tqdm(
|
370 |
+
active_dataloader,
|
371 |
+
desc="Training",
|
372 |
+
disable=not accelerator.is_local_main_process,
|
373 |
+
dynamic_ncols=True,
|
374 |
+
colour="CYAN",
|
375 |
+
)
|
376 |
+
for step, batch in enumerate(epoch_iterator):
|
377 |
+
with accelerator.accumulate(model):
|
378 |
+
outputs = model(**batch)
|
379 |
+
loss = outputs.loss
|
380 |
+
# We keep track of the loss at each epoch
|
381 |
+
if args.with_tracking:
|
382 |
+
total_loss += outputs.lm_loss.item()
|
383 |
+
if args.encodec_masking_prob > 0:
|
384 |
+
if outputs.encodec_loss is not None:
|
385 |
+
total_encodec_loss += outputs.encodec_loss.item()
|
386 |
+
accelerator.backward(loss)
|
387 |
+
if accelerator.sync_gradients:
|
388 |
+
accelerator.clip_grad_norm_(
|
389 |
+
model.parameters(), max_norm=args.max_grad_norm
|
390 |
+
)
|
391 |
+
optimizer.step()
|
392 |
+
lr_scheduler.step()
|
393 |
+
optimizer.zero_grad()
|
394 |
+
|
395 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
396 |
+
if accelerator.sync_gradients:
|
397 |
+
completed_steps += 1
|
398 |
+
# Add loss information to tqdm
|
399 |
+
epoch_iterator.set_postfix(loss=total_loss / completed_steps)
|
400 |
+
|
401 |
+
if completed_steps % args.logging_steps == 0:
|
402 |
+
train_log = {
|
403 |
+
"train/learning_rate": lr_scheduler.get_last_lr()[0]
|
404 |
+
}
|
405 |
+
train_log["train/loss"] = (
|
406 |
+
total_loss - logging_loss
|
407 |
+
) / args.logging_steps
|
408 |
+
logging_loss = total_loss
|
409 |
+
if args.encodec_masking_prob > 0:
|
410 |
+
train_log["train/encodec_loss"] = (
|
411 |
+
total_encodec_loss - logging_encodec_loss
|
412 |
+
) / args.logging_steps
|
413 |
+
logging_encodec_loss = total_encodec_loss
|
414 |
+
accelerator.log(train_log, step=completed_steps)
|
415 |
+
|
416 |
+
if isinstance(checkpointing_steps, int):
|
417 |
+
if completed_steps % checkpointing_steps == 0:
|
418 |
+
output_dir = f"step_{completed_steps }"
|
419 |
+
if args.output_dir is not None:
|
420 |
+
output_dir = os.path.join(
|
421 |
+
args.output_dir, "checkpoints", output_dir
|
422 |
+
)
|
423 |
+
accelerator.save_state(output_dir)
|
424 |
+
|
425 |
+
if completed_steps >= args.max_train_steps:
|
426 |
+
break
|
427 |
+
|
428 |
+
model.eval()
|
429 |
+
gen_kwargs = {
|
430 |
+
"max_length": args.val_max_target_length,
|
431 |
+
}
|
432 |
+
predictions = []
|
433 |
+
references = []
|
434 |
+
eval_iterator = tqdm(
|
435 |
+
eval_dataloader,
|
436 |
+
desc="Validation",
|
437 |
+
disable=not accelerator.is_local_main_process,
|
438 |
+
dynamic_ncols=True,
|
439 |
+
colour="MAGENTA",
|
440 |
+
)
|
441 |
+
for step, batch in enumerate(eval_iterator):
|
442 |
+
# Drop the padded samples of the last batch of dataloader
|
443 |
+
# try:
|
444 |
+
# if accelerator.gradient_state.end_of_dataloader and accelerator.gradient_state.remainder > 0:
|
445 |
+
# batch = batch[:accelerator.gradient_state.remainder]
|
446 |
+
# except:
|
447 |
+
# pass
|
448 |
+
|
449 |
+
with torch.no_grad():
|
450 |
+
batch["input_ids"] = batch["input_ids"].cuda()
|
451 |
+
batch["clap"] = batch["clap"].cuda()
|
452 |
+
batch["attention_mask"] = batch["attention_mask"].cuda()
|
453 |
+
batch["eos_mask"] = batch["eos_mask"].cuda()
|
454 |
+
|
455 |
+
generated_tokens = accelerator.unwrap_model(model).generate(
|
456 |
+
batch["input_ids"],
|
457 |
+
clap=batch["clap"],
|
458 |
+
attention_mask=batch["attention_mask"],
|
459 |
+
eos_mask=batch["eos_mask"],
|
460 |
+
**gen_kwargs,
|
461 |
+
)
|
462 |
+
|
463 |
+
generated_tokens = accelerator.pad_across_processes(
|
464 |
+
generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
|
465 |
+
)
|
466 |
+
generated_tokens = generated_tokens.cpu().numpy()
|
467 |
+
captions = batch["captions"]
|
468 |
+
|
469 |
+
if isinstance(generated_tokens, tuple):
|
470 |
+
generated_tokens = generated_tokens[0]
|
471 |
+
decoded_preds = tokenizer.batch_decode(
|
472 |
+
generated_tokens, skip_special_tokens=True
|
473 |
+
)
|
474 |
+
|
475 |
+
predictions.extend(decoded_preds)
|
476 |
+
references.extend(captions)
|
477 |
+
|
478 |
+
logger.info("Evaluating predictions...")
|
479 |
+
result = evaluate(predictions, references, metrics=metric_list)
|
480 |
+
|
481 |
+
# Gather Result
|
482 |
+
result = {k: v.cuda() for k, v in result[0].items()}
|
483 |
+
result = accelerator.gather_for_metrics(result)
|
484 |
+
# Log the average of metrics among the processes
|
485 |
+
if accelerator.num_processes > 1:
|
486 |
+
result = {f"eval/{k}": round(v.mean().item(), 4) for k, v in result.items()}
|
487 |
+
else:
|
488 |
+
result = {f"eval/{k}": round(v.item(), 4) for k, v in result.items()}
|
489 |
+
logger.info(result)
|
490 |
+
|
491 |
+
if args.with_tracking:
|
492 |
+
result["train/epoch_train_loss"] = (total_loss - before_epoch_loss) / len(
|
493 |
+
train_dataloader
|
494 |
+
)
|
495 |
+
result["train/steps"] = completed_steps
|
496 |
+
before_epoch_loss = total_loss
|
497 |
+
if args.encodec_masking_prob > 0:
|
498 |
+
result["train/epoch_encodec_loss"] = (
|
499 |
+
total_encodec_loss - before_epoch_encodec_loss
|
500 |
+
) / len(train_dataloader)
|
501 |
+
before_epoch_encodec_loss = total_encodec_loss
|
502 |
+
accelerator.log(result, step=epoch)
|
503 |
+
|
504 |
+
if args.checkpointing_steps == "epoch":
|
505 |
+
output_dir = f"epoch_{epoch}"
|
506 |
+
if args.output_dir is not None:
|
507 |
+
output_dir = os.path.join(args.output_dir, "checkpoints", output_dir)
|
508 |
+
accelerator.save_state(output_dir)
|
509 |
+
if accelerator.is_main_process:
|
510 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
511 |
+
unwrapped_model.config.save_pretrained(output_dir)
|
512 |
+
|
513 |
+
if args.output_dir is not None:
|
514 |
+
save_dir = os.path.join(args.output_dir, "final")
|
515 |
+
accelerator.wait_for_everyone()
|
516 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
517 |
+
unwrapped_model.save_pretrained(
|
518 |
+
save_dir,
|
519 |
+
is_main_process=accelerator.is_main_process,
|
520 |
+
save_function=accelerator.save,
|
521 |
+
)
|
522 |
+
if accelerator.is_main_process:
|
523 |
+
tokenizer.save_pretrained(save_dir)
|
524 |
+
|
525 |
+
|
526 |
+
if __name__ == "__main__":
|
527 |
+
main()
|
train.sh
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
CFG_PATH="cfg/clotho/base.yaml"
|
2 |
+
accelerate launch --multi_gpu --main_process_port=1200 train.py $CFG_PATH
|