tonyswoo committed
Commit 73baeae
1 Parent(s): 80d82df

Initial Commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ckpt/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
  title: Enclap
- emoji: 📈
+ emoji: 🔊
- colorFrom: gray
+ colorFrom: pink
- colorTo: purple
+ colorTo: green
  sdk: gradio
- sdk_version: 4.16.0
+ sdk_version: 3.41.0
- app_file: app.py
+ app_file: gradio_app.py
  pinned: false
- license: mit
+ license: openrail
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
cfg/audiocaps/base.yaml ADDED
@@ -0,0 +1,48 @@
+ # Experiment Config for each experiment
+ output_dir: /output
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: csv/audiocaps/train.csv
+ validation_file: csv/audiocaps/valid.csv
+ encodec_base_path: /data/audiocaps/encodec
+ clap_base_path: /data/audiocaps/clap
+ tokenizer_name: facebook/bart-base
+ config_name_or_path: facebook/bart-base
+ model_name_or_path: facebook/bart-base
+ eval_num_captions: 5
+ overwrite_output_dir: False
+
+ # Basic Config
+ encodec_masking_prob: 0.15
+ encodec_masking_span: 10
+ num_train_epochs: 15
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 16
+ per_device_eval_batch_size: 16
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Generation Config
+ max_target_length: 128
+ val_max_target_length: 50
+
+ # Training Hyperparameters
+ # "lr_scheduler_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ learning_rate: 6.5e-5 # peak lr
+ num_warmup_steps: 2000
+ weight_decay: 0.01
+ max_grad_norm: 1.0
+
+ # Others
+ with_tracking: true
+ report_to: tensorboard
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
+ pad_to_max_length: false
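
These training configs are consumed as plain YAML; the snippet below is a minimal sketch (not part of this commit, and assuming PyYAML is installed and the repo root as working directory) of reading one of them. The training entry point that consumes these files is not shown in this diff.

    # Sketch only: PyYAML and the repo-root working directory are assumptions.
    import yaml

    with open("cfg/audiocaps/base.yaml") as f:
        cfg = yaml.safe_load(f)

    # e.g. the scheduler name and peak learning rate defined above
    print(cfg["lr_scheduler_type"], cfg["learning_rate"])
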
cfg/audiocaps/large.yaml ADDED
@@ -0,0 +1,48 @@
+ # Experiment Config for each experiment
+ output_dir: /output
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: csv/audiocaps/train.csv
+ validation_file: csv/audiocaps/valid.csv
+ encodec_base_path: /data/audiocaps/encodec
+ clap_base_path: /data/audiocaps/clap
+ tokenizer_name: facebook/bart-large
+ config_name_or_path: facebook/bart-large
+ model_name_or_path: facebook/bart-large
+ eval_num_captions: 5
+ overwrite_output_dir: False
+
+ # Basic Config
+ encodec_masking_prob: 0.15
+ encodec_masking_span: 10
+ num_train_epochs: 15
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 64
+ per_device_eval_batch_size: 64
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Generation Config
+ max_target_length: 128
+ val_max_target_length: 50
+
+ # Training Hyperparameters
+ # "lr_scheduler_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ learning_rate: 3e-5 # peak lr
+ num_warmup_steps: 2000
+ weight_decay: 0.01
+ max_grad_norm: 1.0
+
+ # Others
+ with_tracking: true
+ report_to: tensorboard
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
+ pad_to_max_length: false
cfg/audiocaps_args.yaml ADDED
@@ -0,0 +1,60 @@
+ # Experiment Config for each experiment
+ output_dir: /data/jyk/aac_results/bart_large/audiocaps_3e5_gpu4_1115_2000
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: /workspace/audiobart/csv/AudioCaps/train.csv
+ validation_file: /workspace/audiobart/csv/AudioCaps/val.csv
+ test_file: /workspace/audiobart/csv/AudioCaps/test.csv
+ base_path: /data/jyk/aac_dataset/AudioCaps/encodec_16
+ clap_base_path: /data/jyk/aac_dataset/AudioCaps/clap_audio_fused
+ tokenizer_name: facebook/bart-large
+ # model_name_or_path: /workspace/audiobart/bart/model
+ model_name_or_path: facebook/bart-large
+ num_captions: 5
+ overwrite_output_dir: False
+
+
+ # Training Configs
+ # Basic Config
+ max_encodec_length: 1022
+ only_encoder_epochs: 0
+ only_encodec_epochs: 0
+ clap_masking_prob: -1
+ encodec_masking_prob: 0.15
+ encodec_masking_length: 10
+ random_sampling: true
+ num_train_epochs: 30
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 64
+ per_device_eval_batch_size: 64
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Model & Generation Config
+ max_source_length: 1024
+ max_target_length: 128
+ val_max_target_length: 50
+ num_beams: null
+ pad_to_max_length: false
+ num_subsampling: 0
+
+ # Training Hyperparameters
+ learning_rate: 3e-5 # peak lr
+ # Should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ # lr_scheduler_type: two_stage_inverse_sqrt
+ weight_decay: 0.01
+ num_warmup_steps: 2000
+ max_grad_norm: 1.0
+
+ # Do not Change
+ with_tracking: true
+ report_to: all
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
cfg/clotho/base.yaml ADDED
@@ -0,0 +1,48 @@
+ # Experiment Config for each experiment
+ output_dir: /output
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: /csv/clotho/train.csv
+ validation_file: /csv/clotho/valid.csv
+ encodec_base_path: /data/clotho/encodec
+ clap_base_path: /data/clotho/clap
+ tokenizer_name: facebook/bart-base
+ config_name_or_path: facebook/bart-base
+ model_name_or_path: facebook/bart-base
+ eval_num_captions: 5
+ overwrite_output_dir: False
+
+ # Basic Config
+ encodec_masking_prob: 0.15
+ encodec_masking_span: 10
+ num_train_epochs: 15
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 64
+ per_device_eval_batch_size: 64
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Generation Config
+ max_target_length: 128
+ val_max_target_length: 50
+
+ # Training Hyperparameters
+ # "lr_scheduler_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ learning_rate: 4e-5 # peak lr
+ num_warmup_steps: 1000
+ weight_decay: 0.01
+ max_grad_norm: 1.0
+
+ # Others
+ with_tracking: true
+ report_to: tensorboard
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
+ pad_to_max_length: false
cfg/clotho/large.yaml ADDED
@@ -0,0 +1,48 @@
+ # Experiment Config for each experiment
+ output_dir: /output
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: /csv/clotho/train.csv
+ validation_file: /csv/clotho/valid.csv
+ encodec_base_path: /data/clotho/encodec
+ clap_base_path: /data/clotho/clap
+ tokenizer_name: facebook/bart-large
+ config_name_or_path: facebook/bart-large
+ model_name_or_path: facebook/bart-large
+ eval_num_captions: 5
+ overwrite_output_dir: False
+
+ # Basic Config
+ encodec_masking_prob: 0.15
+ encodec_masking_span: 10
+ num_train_epochs: 15
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 64
+ per_device_eval_batch_size: 64
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Generation Config
+ max_target_length: 128
+ val_max_target_length: 50
+
+ # Training Hyperparameters
+ # "lr_scheduler_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ learning_rate: 2.5e-5 # peak lr
+ num_warmup_steps: 1000
+ weight_decay: 0.01
+ max_grad_norm: 1.0
+
+ # Others
+ with_tracking: true
+ report_to: tensorboard
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
+ pad_to_max_length: false
cfg/clotho_finetune/base.yaml ADDED
@@ -0,0 +1,48 @@
+ # Experiment Config for each experiment
+ output_dir: /output
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: /csv/clotho/train.csv
+ validation_file: /csv/clotho/valid.csv
+ encodec_base_path: /data/clotho/encodec
+ clap_base_path: /data/clotho/clap
+ tokenizer_name: facebook/bart-base
+ config_name_or_path: facebook/bart-base
+ model_name_or_path: /data/enclap_audiocaps
+ eval_num_captions: 5
+ overwrite_output_dir: False
+
+ # Basic Config
+ encodec_masking_prob: 0.15
+ encodec_masking_span: 10
+ num_train_epochs: 15
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 64
+ per_device_eval_batch_size: 64
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Generation Config
+ max_target_length: 128
+ val_max_target_length: 50
+
+ # Training Hyperparameters
+ # "lr_scheduler_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ learning_rate: 2e-5 # peak lr
+ num_warmup_steps: 1000
+ weight_decay: 0.01
+ max_grad_norm: 1.0
+
+ # Others
+ with_tracking: true
+ report_to: tensorboard
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
+ pad_to_max_length: false
cfg/clotho_finetune/large.yaml ADDED
@@ -0,0 +1,48 @@
+ # Experiment Config for each experiment
+ output_dir: /output
+ logging_dir: runs/tb_log
+ logging_steps: 10
+ seed: 1115
+ train_file: /csv/clotho/train.csv
+ validation_file: /csv/clotho/valid.csv
+ encodec_base_path: /data/clotho/encodec
+ clap_base_path: /data/clotho/clap
+ tokenizer_name: facebook/bart-large
+ config_name_or_path: facebook/bart-large
+ model_name_or_path: /data/enclap_audiocaps
+ eval_num_captions: 5
+ overwrite_output_dir: False
+
+ # Basic Config
+ encodec_masking_prob: 0.15
+ encodec_masking_span: 10
+ num_train_epochs: 15
+ max_train_steps: null
+ gradient_accumulation_steps: 1
+ per_device_train_batch_size: 64
+ per_device_eval_batch_size: 64
+ split_batches: true
+ checkpointing_steps: epoch # 'epoch' to save for each epoch, or number of steps
+ resume_from_checkpoint: null
+
+ # Generation Config
+ max_target_length: 128
+ val_max_target_length: 50
+
+ # Training Hyperparameters
+ # "lr_scheduler_type" should be one of "linear", "cosine", "cosine_with_restarts", "polynomial",
+ # "constant", "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau", "two_stage_inverse_sqrt"
+ lr_scheduler_type: inverse_sqrt
+ learning_rate: 1.25e-5 # peak lr
+ num_warmup_steps: 1000
+ weight_decay: 0.01
+ max_grad_norm: 1.0
+
+ # Others
+ with_tracking: true
+ report_to: tensorboard
+ ignore_pad_token_for_loss: true
+ preprocessing_num_workers: 32
+ use_slow_tokenizer: false
+ overwrite_cache: false
+ pad_to_max_length: false
ckpt/config.json ADDED
@@ -0,0 +1,75 @@
+ {
+   "_name_or_path": "facebook/bart-base",
+   "activation_dropout": 0.1,
+   "activation_function": "gelu",
+   "add_bias_logits": false,
+   "add_final_layer_norm": false,
+   "architectures": [
+     "BartModel"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": 0,
+   "classif_dropout": 0.1,
+   "classifier_dropout": 0.0,
+   "d_model": 768,
+   "decoder_attention_heads": 12,
+   "decoder_ffn_dim": 3072,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 6,
+   "decoder_start_token_id": 2,
+   "dropout": 0.1,
+   "early_stopping": true,
+   "encoder_attention_heads": 12,
+   "encoder_ffn_dim": 3072,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 6,
+   "eos_token_id": 2,
+   "forced_bos_token_id": 0,
+   "forced_eos_token_id": 2,
+   "gradient_checkpointing": false,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "max_position_embeddings": 1024,
+   "model_type": "bart",
+   "no_repeat_ngram_size": 3,
+   "normalize_before": false,
+   "normalize_embedding": true,
+   "num_beams": 4,
+   "num_hidden_layers": 6,
+   "pad_token_id": 1,
+   "scale_embedding": false,
+   "task_specific_params": {
+     "summarization": {
+       "length_penalty": 1.0,
+       "max_length": 128,
+       "min_length": 12,
+       "num_beams": 4
+     },
+     "summarization_cnn": {
+       "length_penalty": 2.0,
+       "max_length": 142,
+       "min_length": 56,
+       "num_beams": 4
+     },
+     "summarization_xsum": {
+       "length_penalty": 1.0,
+       "max_length": 62,
+       "min_length": 11,
+       "num_beams": 6
+     }
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.29.0",
+   "use_cache": true,
+   "vocab_size": 50265
+ }
ckpt/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fef9cc30fdf47b82a8bb846d418ac3ef893b4d10e909fafbbe3ed8a1931cf23
+ size 663433954
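
The ckpt/ directory pairs a BART-base-sized config.json with an LFS-tracked pytorch_model.bin. Below is a hedged sketch (not part of the commit) of loading it the same way inference.py does further down in this diff; the relative "./ckpt" path assumes the repo root as working directory.

    # Sketch only: mirrors the from_pretrained calls used in inference.py below.
    from modeling.enclap_bart import EnClapBartConfig, EnClapBartForConditionalGeneration

    config = EnClapBartConfig.from_pretrained("./ckpt")
    model = EnClapBartForConditionalGeneration.from_pretrained("./ckpt").eval()
    print(config.d_model)  # 768 for the BART-base backbone
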
csv/audiocaps/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv/audiocaps/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv/audiocaps/valid.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv/clotho/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv/clotho/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv/clotho/valid.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/__init__.py ADDED
File without changes
data/collator.py ADDED
@@ -0,0 +1,61 @@
+ from dataclasses import dataclass
+
+ import torch
+ from transformers import BatchEncoding, DataCollatorForSeq2Seq
+
+
+ @dataclass
+ class DataCollatorForEnClapBart(DataCollatorForSeq2Seq):
+     input_pad_token_id: int = 1024
+     num_rvq: int = 16
+
+     def __call__(self, features, return_tensors=None):
+         if return_tensors is None:
+             return_tensors = self.return_tensors
+
+         batch_size = len(features)
+         # stacked_features = {k: [f[k] for f in features] for k in features[0]}
+         clap_embedding = torch.Tensor(
+             [feature["clap_embedding"] for feature in features]
+         )
+
+         # Pad the per-codebook EnCodec ids (and their MCM labels) with the EnCodec pad id
+         pad_token_id = self.tokenizer.pad_token_id
+         self.tokenizer.pad_token_id = self.input_pad_token_id
+         keys = ["input_ids", "mcm_labels"]
+         tmp_key_map = {"input_ids": "input_ids", "mcm_labels": "labels"}
+         input_features = super().__call__(
+             [
+                 {tmp_key_map[key]: feature[key][:, i] for key in keys}
+                 for feature in features
+                 for i in range(feature[keys[0]].shape[-1])
+             ],
+             return_tensors,
+         )
+
+         # Pad the remaining features with the regular BART pad id
+         self.tokenizer.pad_token_id = 1
+         keys = ["encodec_mask", "attention_mask", "labels"]
+         tmp_key_map = {
+             "encodec_mask": "input_ids",
+             "attention_mask": "attention_mask",
+             "labels": "labels",
+         }
+         other_features = super().__call__(
+             [{tmp_key_map[key]: feature[key] for key in keys} for feature in features],
+             return_tensors,
+         )
+         self.tokenizer.pad_token_id = pad_token_id
+
+         return BatchEncoding(
+             {
+                 "input_ids": input_features["input_ids"]
+                 .reshape(batch_size, self.num_rvq, -1)
+                 .transpose(1, 2),
+                 "mcm_labels": input_features["labels"]
+                 .reshape(batch_size, self.num_rvq, -1)
+                 .transpose(1, 2),
+                 "attention_mask": other_features["attention_mask"],
+                 "encodec_mask": other_features["input_ids"],
+                 "labels": other_features["labels"],
+                 "clap_embedding": clap_embedding,
+             }
+         )
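
A sketch (not part of the commit) of how this collator could be wired to a DataLoader; `train_dataset` is a hypothetical dataset whose items carry the keys produced by Preprocessor.preprocess_train in data/preprocess.py below.

    # Sketch only: the dataset object is hypothetical.
    from torch.utils.data import DataLoader
    from transformers import BartTokenizerFast

    from data.collator import DataCollatorForEnClapBart

    tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
    collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer, input_pad_token_id=1024, num_rvq=16
    )
    # loader = DataLoader(train_dataset, batch_size=16, collate_fn=collator)
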
data/infer_clap.py ADDED
@@ -0,0 +1,67 @@
+ import argparse
+ from pathlib import Path
+
+ import librosa
+ import numpy as np
+ import torch
+ from laion_clap import CLAP_Module
+ from tqdm import tqdm
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data_path",
+         "-d",
+         required=True,
+         type=str,
+         help="Path of the original wav files",
+     )
+     parser.add_argument(
+         "--save_path",
+         "-s",
+         required=True,
+         type=str,
+         help="Path to save the clap audio embedding '.npy' files",
+     )
+     parser.add_argument(
+         "--clap_ckpt",
+         "-c",
+         required=True,
+         type=str,
+         help="Path of the pretrained clap checkpoint",
+     )
+     parser.add_argument(
+         "--enable_fusion",
+         "-e",
+         default=True,
+         type=bool,
+         help="Whether to enable the feature fusion of the clap model. Depends on the clap checkpoint you are using",
+     )
+     parser.add_argument(
+         "--audio_encoder",
+         "-a",
+         default="HTSAT-tiny",
+         type=str,
+         help="Audio encoder of the clap model. Depends on the clap checkpoint you are using",
+     )
+     args = parser.parse_args()
+
+     # CLAP_Module takes the audio encoder via `amodel` (as in inference.py)
+     model = CLAP_Module(enable_fusion=args.enable_fusion, amodel=args.audio_encoder)
+     model.load_ckpt(args.clap_ckpt)
+     data_path = Path(args.data_path)
+     save_path = Path(args.save_path)
+
+     with torch.no_grad():
+         for wav_path in tqdm(
+             data_path.glob("**/*.wav"), dynamic_ncols=True, colour="yellow"
+         ):
+             wav, _ = librosa.load(wav_path, sr=48000)
+
+             clap_embedding = model.get_audio_embedding_from_data(
+                 x=wav[np.newaxis], use_tensor=False
+             )
+             clap_embedding = clap_embedding.squeeze(axis=0)
+
+             out_path = save_path / wav_path.with_suffix(".npy").relative_to(data_path)
+             out_path.parent.mkdir(exist_ok=True)
+             np.save(out_path, clap_embedding)
data/infer_encodec.py ADDED
@@ -0,0 +1,41 @@
+ import argparse
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from encodec import EncodecModel
+ from encodec.utils import convert_audio
+ from tqdm import tqdm
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data_path", type=str, required=True, help="Path of the original wav files"
+     )
+     parser.add_argument(
+         "--save_path", type=str, required=True, help="Path to save encodec .npy files"
+     )
+     args = parser.parse_args()
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = EncodecModel.encodec_model_24khz()
+     model.set_target_bandwidth(12.0)
+     model = model.to(device)
+
+     data_path = Path(args.data_path)
+     save_path = Path(args.save_path)
+
+     with torch.no_grad():
+         for wav_path in tqdm(data_path.glob("**/*.wav")):
+             wav, sr = torchaudio.load(wav_path)
+             wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+             wav = wav.unsqueeze(0).to(device)
+             encoded_frames = model.encode(wav)
+
+             codes = torch.cat([codebook for codebook, _ in encoded_frames], dim=-1)
+             codes = codes.cpu().squeeze(0).transpose(-1, -2).detach().numpy()
+
+             out_path = save_path / wav_path.with_suffix(".npy").relative_to(data_path)
+             out_path.parent.mkdir(exist_ok=True)
+             np.save(out_path, codes)
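
At 24 kHz with a 12 kbps target bandwidth, EnCodec emits 16 codebooks per frame, so each array saved above should have shape (num_frames, 16), matching num_rvq: 16 in the configs. A small sanity-check sketch (not part of the commit; the file name is a placeholder):

    # Sketch only: "example.npy" stands in for one of the saved files.
    import numpy as np

    codes = np.load("example.npy")
    print(codes.shape)  # expected (num_frames, 16)
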
data/preprocess.py ADDED
@@ -0,0 +1,232 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ from random import randint
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+ from transformers import BartTokenizerFast
+
+
+ @dataclass
+ class Preprocessor:
+     encodec_base_path: Path
+     clap_base_path: Path
+     tokenizer: BartTokenizerFast = BartTokenizerFast.from_pretrained(
+         "facebook/bart-base"
+     )
+     max_length: int = 1024
+     mcm_masking_prob: float = 0.15
+     mcm_masking_span: int = 10
+     label_pad_token_id: int = -100
+     mask_token_id: int = 1024
+     num_eval_captions: int = 5
+
+     def __post_init__(self):
+         if isinstance(self.encodec_base_path, str):
+             self.encodec_base_path = Path(self.encodec_base_path)
+         if isinstance(self.clap_base_path, str):
+             self.clap_base_path = Path(self.clap_base_path)
+         if isinstance(self.tokenizer, str):
+             self.tokenizer = BartTokenizerFast.from_pretrained(self.tokenizer)
+
+     def preprocess_train(self, example):
+         path = example["file_path"]
+         encodec = np.load(self.encodec_base_path / path)
+         clap_embedding = np.load(self.clap_base_path / path)
+         encodec_mask = np.array(
+             [0, 0] + [1] * min(encodec.shape[0], self.max_length - 3) + [0]
+         )
+         attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
+             np.int64
+         )
+         target_text = self.tokenizer(text_target=example["caption"])
+
+         if encodec.shape[0] + 3 > self.max_length:
+             start = randint(0, encodec.shape[0] - self.max_length + 3)
+             encodec = encodec[start : start + self.max_length - 3]
+
+         mcm_labels = None
+         num_rvq = encodec.shape[-1]  # defined here so the bos/eos padding below works even without masking
+         if self.mcm_masking_prob > 0:
+             mcm_mask, _ = _compute_mask_indices(
+                 encodec.T.shape, self.mcm_masking_prob, self.mcm_masking_span
+             )
+             mcm_mask = mcm_mask.T
+             mcm_labels = np.where(mcm_mask, encodec, self.label_pad_token_id)
+             mcm_labels = np.concatenate(
+                 [
+                     np.ones((2, num_rvq), dtype=np.int64) * self.label_pad_token_id,
+                     mcm_labels,
+                     np.ones((1, num_rvq), dtype=np.int64) * self.label_pad_token_id,
+                 ],
+                 axis=0,
+             )
+             encodec[mcm_mask] = self.mask_token_id
+
+         encodec = np.concatenate(
+             [
+                 np.ones((2, num_rvq), dtype=np.int64) * self.tokenizer.bos_token_id,
+                 encodec,
+                 np.ones((1, num_rvq), dtype=np.int64) * self.tokenizer.eos_token_id,
+             ],
+             axis=0,
+         )
+
+         return {
+             "input_ids": encodec,
+             "clap_embedding": clap_embedding,
+             "encodec_mask": encodec_mask,
+             "attention_mask": attention_mask,
+             "mcm_labels": mcm_labels,
+             "labels": target_text["input_ids"],
+         }
+
+     def preprocess_eval(self, example):
+         path = example["file_path"]
+         encodec = np.load(self.encodec_base_path / path)
+         clap_embedding = np.load(self.clap_base_path / path)
+         attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
+             np.int64
+         )
+
+         if encodec.shape[0] + 3 > self.max_length:
+             encodec = encodec[: self.max_length - 3]
+
+         captions = []
+         for i in range(self.num_eval_captions):
+             captions.append(example[f"caption_{i+1}"])
+
+         return {
+             "input_ids": encodec,
+             "attention_mask": attention_mask,
+             "clap": clap_embedding,
+             "captions": captions,
+         }
+
+
+ def _compute_mask_indices(
+     shape: Tuple[int, int],
+     mask_prob: float,
+     mask_length: int,
+     attention_mask: Optional[torch.LongTensor] = None,
+     min_masks: int = 0,
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+     ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+     CPU as part of the preprocessing during training.
+
+     Args:
+         shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+             the first element is the batch size and the second element is the length of the axis to span.
+         mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+             independently generated mask spans of length `mask_length` is computed by
+             `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+             actual percentage will be smaller.
+         mask_length: size of the mask
+         min_masks: minimum number of masked spans
+         attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+             each batch dimension.
+     """
+     batch_size, sequence_length = shape
+
+     if mask_length < 1:
+         raise ValueError("`mask_length` has to be bigger than 0.")
+
+     if mask_length > sequence_length:
+         raise ValueError(
+             f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+             f" and `sequence_length`: {sequence_length}`"
+         )
+
+     # epsilon is used for probabilistic rounding
+     epsilon = np.random.rand(1).item()
+
+     def compute_num_masked_span(input_length):
+         """Given input length, compute how many spans should be masked"""
+         num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+         num_masked_span = max(num_masked_span, min_masks)
+
+         # make sure num masked span <= sequence_length
+         if num_masked_span * mask_length > sequence_length:
+             num_masked_span = sequence_length // mask_length
+
+         # make sure num_masked span is also <= input_length - (mask_length - 1)
+         if input_length - (mask_length - 1) < num_masked_span:
+             num_masked_span = max(input_length - (mask_length - 1), 0)
+
+         return num_masked_span
+
+     # compute number of masked spans in batch
+     input_lengths = (
+         attention_mask.sum(-1).detach().tolist()
+         if attention_mask is not None
+         else [sequence_length for _ in range(batch_size)]
+     )
+
+     # SpecAugment mask to fill
+     spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+     spec_aug_mask_idxs = []
+
+     max_num_masked_span = compute_num_masked_span(sequence_length)
+
+     if max_num_masked_span == 0:
+         return spec_aug_mask
+
+     for input_length in input_lengths:
+         # compute num of masked spans for this input
+         num_masked_span = compute_num_masked_span(input_length)
+
+         # get random indices to mask
+         spec_aug_mask_idx = np.random.choice(
+             np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+         )
+
+         # pick first sampled index that will serve as a dummy index to pad vector
+         # to ensure same dimension for all batches due to probabilistic rounding
+         # Picking first sample just pads those vectors twice.
+         if len(spec_aug_mask_idx) == 0:
+             # this case can only happen if `input_length` is strictly smaller then
+             # `sequence_length` in which case the last token has to be a padding
+             # token which we can use as a dummy mask id
+             dummy_mask_idx = sequence_length - 1
+         else:
+             dummy_mask_idx = spec_aug_mask_idx[0]
+
+         spec_aug_mask_idx = np.concatenate(
+             [
+                 spec_aug_mask_idx,
+                 np.ones(max_num_masked_span - num_masked_span, dtype=np.int32)
+                 * dummy_mask_idx,
+             ]
+         )
+         spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+     spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+     # expand masked indices to masked spans
+     spec_aug_mask_idxs = np.broadcast_to(
+         spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+     )
+     spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
+         batch_size, max_num_masked_span * mask_length
+     )
+
+     # add offset to the starting indexes so that indexes now create a span
+     offsets = np.arange(mask_length)[None, None, :]
+     offsets = np.broadcast_to(
+         offsets, (batch_size, max_num_masked_span, mask_length)
+     ).reshape(batch_size, max_num_masked_span * mask_length)
+     spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+     # ensure that we cannot have indices larger than sequence_length
+     if spec_aug_mask_idxs.max() > sequence_length - 1:
+         spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = (
+             sequence_length - 1
+         )
+
+     # scatter indices to mask
+     np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+     return torch.from_numpy(spec_aug_mask), spec_aug_mask_idxs
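
A quick sketch (not part of the commit) exercising the span-masking helper with the same settings as encodec_masking_prob / encodec_masking_span in the configs; the (16, 100) shape is arbitrary and only illustrative.

    # Sketch only: illustrative shape and settings.
    from data.preprocess import _compute_mask_indices

    mask, mask_idxs = _compute_mask_indices((16, 100), mask_prob=0.15, mask_length=10)
    print(mask.shape)    # torch.Size([16, 100]), boolean span mask
    print(mask.sum(-1))  # masked steps per row, roughly 15% of the sequence
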
gradio_app.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ from typing import Tuple
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from transformers import AutoProcessor
+
+ from inference import EnClap
+
+
+ def input_toggle(choice: str):
+     if choice == "file":
+         return gr.update(visible=True), gr.update(visible=False)
+     return gr.update(visible=False), gr.update(visible=True)
+
+
+ if __name__ == "__main__":
+     import logging
+
+     logging.getLogger().setLevel(logging.INFO)
+     ckpt_path = "./ckpt"  # os.getenv("ckpt_path")
+     device = "cpu"  # os.getenv("device")
+
+     enclap = EnClap(ckpt_path=ckpt_path, device=device)
+
+     def run_enclap(
+         input_type: str,
+         file_input: Tuple[int, np.ndarray],
+         mic_input: Tuple[int, np.ndarray],
+         seed: int,
+     ) -> str:
+         print(input_type, file_input, mic_input)
+         input = file_input if input_type == "file" else mic_input
+         if input is None:
+             raise gr.Error("Input audio was not provided.")
+         res, audio = input
+         torch.manual_seed(seed)
+         return enclap.infer_from_audio(torch.from_numpy(audio), res)[0]
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 radio = gr.Radio(
+                     ["file", "mic"],
+                     value="file",
+                     label="Choose the input method of the audio.",
+                 )
+                 file = gr.Audio(label="Input", visible=True)
+                 mic = gr.Mic(label="Input", visible=False)
+                 slider = gr.Slider(minimum=0, maximum=100, label="Seed")
+                 radio.change(fn=input_toggle, inputs=radio, outputs=[file, mic])
+                 button = gr.Button("Run", label="run")
+             with gr.Column():
+                 output = gr.Text(label="Output")
+                 button.click(
+                     fn=run_enclap, inputs=[radio, file, mic, slider], outputs=output
+                 )
+
+     demo.launch()
inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,161 @@
+ from typing import Any, Dict
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from encodec import EncodecModel
+ from encodec.utils import convert_audio
+ from laion_clap import CLAP_Module
+ from transformers import AutoTokenizer
+
+ from modeling.enclap_bart import EnClapBartConfig, EnClapBartForConditionalGeneration
+
+
+ class EnClap:
+     def __init__(
+         self,
+         ckpt_path: str,
+         clap_audio_model: str = "HTSAT-tiny",
+         clap_enable_fusion: bool = True,
+         device: str = "cuda",
+     ):
+         config = EnClapBartConfig.from_pretrained(ckpt_path)
+         self.device = device
+         self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
+         self.model = (
+             EnClapBartForConditionalGeneration.from_pretrained(ckpt_path)
+             .to(self.device)
+             .eval()
+         )
+
+         self.encodec = EncodecModel.encodec_model_24khz().to(self.device)
+         self.encodec.set_target_bandwidth(12.0)
+
+         self.clap_model = CLAP_Module(
+             enable_fusion=clap_enable_fusion,
+             amodel=clap_audio_model,
+             device=self.device,
+         )
+         self.clap_model.load_ckpt()
+
+         self.generation_config = {
+             "_from_model_config": True,
+             "bos_token_id": 0,
+             "decoder_start_token_id": 2,
+             "early_stopping": True,
+             "eos_token_id": 2,
+             "forced_bos_token_id": 0,
+             "forced_eos_token_id": 2,
+             "no_repeat_ngram_size": 3,
+             "num_beams": 4,
+             "pad_token_id": 1,
+             "max_length": 50,
+         }
+         self.scale_factor = 2**15
+         self.max_seq_len = config.max_position_embeddings - 3
+
+     @torch.no_grad()
+     def infer_from_audio_file(
+         self, audio_file: str, generation_config: Dict[str, Any] = None
+     ) -> str:
+         if generation_config is None:
+             generation_config = self.generation_config
+         audio, res = torchaudio.load(audio_file)
+         return self.infer_from_audio(audio[0], res, generation_config)
+
+     @torch.no_grad()
+     def infer_from_audio(
+         self, audio: torch.Tensor, res: int, generation_config: Dict[str, Any] = None
+     ) -> str:
+         if generation_config is None:
+             generation_config = self.generation_config
+         if audio.dtype == torch.int or audio.dtype == torch.short:
+             audio = audio / self.scale_factor
+         encodec_audio = (
+             convert_audio(
+                 audio.unsqueeze(0), res, self.encodec.sample_rate, self.encodec.channels
+             )
+             .unsqueeze(0)
+             .to(self.device)
+         )
+         encodec_frames = self.encodec.encode(encodec_audio)
+         encodec_frames = torch.cat(
+             [codebook for codebook, _ in encodec_frames], dim=-1
+         ).mT
+
+         clap_audio = torchaudio.transforms.Resample(res, 48000)(audio).unsqueeze(0)
+         clap_embedding = self.clap_model.get_audio_embedding_from_data(
+             clap_audio, use_tensor=True
+         )
+
+         return self._infer(encodec_frames, clap_embedding, generation_config)
+
+     @torch.no_grad()
+     def _infer(
+         self,
+         encodec_frames: torch.LongTensor,
+         clap_embedding: torch.Tensor,
+         generation_config: Dict[str, Any] = None,
+     ) -> str:
+         input_ids = torch.cat(
+             [
+                 torch.ones(
+                     (encodec_frames.shape[0], 2, encodec_frames.shape[-1]),
+                     dtype=torch.long,
+                 ).to(self.device)
+                 * self.tokenizer.bos_token_id,
+                 encodec_frames[:, : self.max_seq_len],
+                 torch.ones(
+                     (encodec_frames.shape[0], 1, encodec_frames.shape[-1]),
+                     dtype=torch.long,
+                 ).to(self.device)
+                 * self.tokenizer.eos_token_id,
+             ],
+             dim=1,
+         )
+         encodec_mask = torch.LongTensor(
+             [[0, 0] + [1] * (input_ids.shape[1] - 3) + [0]]
+         ).to(self.device)
+
+         enclap_bart_inputs = {
+             "input_ids": input_ids,
+             "encodec_mask": encodec_mask,
+             "clap_embedding": clap_embedding,
+         }
+
+         results = self.model.generate(**enclap_bart_inputs, **generation_config)
+         caption = self.tokenizer.batch_decode(results, skip_special_tokens=True)
+
+         return caption
+
+     @torch.no_grad()
+     def infer_from_encodec(
+         self,
+         file_path,
+         clap_path: str = "clap",
+         generation_config: Dict[str, Any] = None,
+     ):
+         if generation_config is None:
+             generation_config = self.generation_config
+         input_ids = np.load(file_path)
+         if input_ids.shape[0] > self.max_encodec_length:
+             input_ids = input_ids[: self.max_encodec_length, :]
+         input_length = input_ids.shape[0]
+         input_ids = np.concatenate([input_ids, self.eos_padding], axis=0)
+         input_ids = torch.LongTensor(input_ids)
+         input_ids = input_ids.unsqueeze(0).to(self.device)
+         attention_mask = (
+             torch.ones(input_length + 3, dtype=torch.int64).unsqueeze(0).to(self.device)
+         )
+         eos_mask = [0] * (input_length + 3)
+         eos_mask[input_length + 2] = 1
+         eos_mask = torch.BoolTensor(eos_mask).unsqueeze(0)
+         # Load CLAP
+         clap_path = file_path.replace("encodec_16", clap_path)
+         clap = np.load(clap_path)
+         clap = torch.Tensor(clap).unsqueeze(0).to(self.device)
+         input = {
+             "input_ids": input_ids,
+             "clap": clap,
+             "attention_mask": attention_mask,
+             "eos_mask": eos_mask,
+         }
+
+         generated_ids = self.model.generate(**input, **generation_config)
+         text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+         return text
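
A usage sketch (not part of the commit) mirroring how gradio_app.py above constructs the model; "./ckpt" follows that script, and "sample.wav" is a placeholder path.

    # Sketch only: the wav path is a placeholder.
    from inference import EnClap

    enclap = EnClap(ckpt_path="./ckpt", device="cpu")
    caption = enclap.infer_from_audio_file("sample.wav")[0]
    print(caption)
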
metric/__init__.py ADDED
File without changes
metric/compute_metric.py ADDED
@@ -0,0 +1,24 @@
+ import copy
+
+ import pandas as pd
+ from aac_metrics import evaluate
+
+ metric_list = ["bleu_1", "bleu_4", "rouge_l", "meteor", "spider_fl"]
+
+ if __name__ == "__main__":
+     csv_path = "/workspace/audiobart/csv/predictions/prediction_clap.csv"
+     df = pd.read_csv(csv_path)
+
+     predictions = []
+     references = []
+     for idx in range(len(df)):
+         predictions.append(df.loc[idx]["prediction"])
+         reference = [
+             df.loc[idx]["caption_1"],
+             df.loc[idx]["caption_2"],
+             df.loc[idx]["caption_3"],
+             df.loc[idx]["caption_4"],
+             df.loc[idx]["caption_5"],
+         ]
+         references.append(reference)
+
+     print("> Evaluating predictions...")
+     result = evaluate(predictions, references, metrics=metric_list)
+     result = {k: v.item() for k, v in result[0].items()}
+     keys = list(result.keys())
+     for key in keys:
+         if "fluerr" in key:
+             del result[key]
+     print(result)
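
A toy sketch (not part of the commit) of the aac_metrics call pattern used above, with made-up captions and a reduced metric list; evaluate() returns a (corpus_scores, sentence_scores) pair, which is why the script indexes result[0].

    # Sketch only: captions are invented examples.
    from aac_metrics import evaluate

    predictions = ["a dog barks while cars pass by"]
    references = [[
        "a dog is barking near a road",
        "traffic noise with a dog barking",
        "a dog barks as vehicles drive past",
        "barking dog and passing cars",
        "a dog barks outdoors",
    ]]
    corpus_scores, _ = evaluate(predictions, references, metrics=["bleu_1", "rouge_l"])
    print({k: v.item() for k, v in corpus_scores.items()})
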
metric/compute_metric_from_scratch.py ADDED
@@ -0,0 +1,70 @@
+ import sys
+
+ sys.path.append("..")
+ sys.path.append(".")
+
+ import os
+
+ import pandas as pd
+ from aac_metrics import evaluate
+ from tqdm import tqdm
+
+ from inference import AudioBartInference
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+ metric_list = ["bleu_1", "bleu_4", "rouge_l", "meteor", "spider_fl"]
+
+ if __name__ == "__main__":
+     dataset = "AudioCaps"
+     # dataset = "clotho"
+     ckpt_path = "/data/jyk/aac_results/bart_base/audiocaps_35e5_2000/checkpoints/epoch_8"
+
+     # ckpt_path = "/data/jyk/aac_results/masking/linear_scalinEg/checkpoints/epoch_14"
+     max_encodec_length = 1022
+     infer_module = AudioBartInference(ckpt_path, max_encodec_length)
+     from_encodec = True
+     csv_path = f"/workspace/audiobart/csv/{dataset}/test.csv"
+     base_path = f"/data/jyk/aac_dataset/{dataset}/encodec_16"
+     clap_name = "clap_audio_fused"
+     df = pd.read_csv(csv_path)
+
+     generation_config = {
+         "_from_model_config": True,
+         "bos_token_id": 0,
+         "decoder_start_token_id": 2,
+         "early_stopping": True,
+         "eos_token_id": 2,
+         "forced_bos_token_id": 0,
+         "forced_eos_token_id": 2,
+         "no_repeat_ngram_size": 3,
+         "num_beams": 4,
+         "pad_token_id": 1,
+         "max_length": 50,
+     }
+
+     print(f"> Making Predictions for model {ckpt_path}...")
+     predictions = []
+     references = []
+     for idx in tqdm(range(len(df)), dynamic_ncols=True, colour="BLUE"):
+         if not from_encodec:
+             wav_path = df.loc[idx]["file_name"]
+         else:
+             wav_path = df.loc[idx]["file_path"]
+         wav_path = os.path.join(base_path, wav_path)
+         if not os.path.exists(wav_path):
+             continue  # skip missing files
+
+         if not from_encodec:
+             prediction = infer_module.infer(wav_path)
+         else:
+             prediction = infer_module.infer_from_encodec(
+                 wav_path, clap_name, generation_config
+             )
+
+         predictions.append(prediction[0])
+         reference = [
+             df.loc[idx]["caption_1"],
+             df.loc[idx]["caption_2"],
+             df.loc[idx]["caption_3"],
+             df.loc[idx]["caption_4"],
+             df.loc[idx]["caption_5"],
+         ]
+         references.append(reference)
+
+     print("> Evaluating predictions...")
+     result = evaluate(predictions, references, metrics=metric_list)
+     result = {k: round(v.item(), 4) for k, v in result[0].items()}
+     keys = list(result.keys())
+     for key in keys:
+         if "fluerr" in key:
+             del result[key]
+     print(result)
metric/make_predictions.py ADDED
@@ -0,0 +1,41 @@
+ import sys
+
+ sys.path.append("..")
+
+ import csv
+ import os
+
+ import pandas as pd
+ from tqdm import tqdm
+
+ from inference import AudioBartInference
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+
+
+ if __name__ == "__main__":
+     ckpt_path = "/data/jyk/aac_results/clap/clap/checkpoints/epoch_12"
+     infer_module = AudioBartInference(ckpt_path)
+     from_encodec = True
+     csv_path = "/workspace/audiobart/csv/test.csv"
+     base_path = "/data/jyk/aac_dataset/clotho/encodec"
+     df = pd.read_csv(csv_path)
+     save_path = "/workspace/audiobart/csv/predictions/prediction_clap.csv"
+     f = open(save_path, "w", newline="")
+     writer = csv.writer(f)
+     writer.writerow(
+         ["file_path", "prediction", "caption_1", "caption_2", "caption_3", "caption_4", "caption_5"]
+     )
+
+     print(f"> Making Predictions for model {ckpt_path}...")
+     for idx in tqdm(range(len(df)), dynamic_ncols=True, colour="red"):
+         if not from_encodec:
+             wav_path = df.loc[idx]["file_name"]
+         else:
+             wav_path = df.loc[idx]["file_path"]
+         wav_path = os.path.join(base_path, wav_path)
+         if not os.path.exists(wav_path):
+             continue  # skip missing files
+
+         if not from_encodec:
+             prediction = infer_module.infer(wav_path)
+         else:
+             prediction = infer_module.infer_from_encodec(wav_path)
+         line = [
+             wav_path,
+             prediction[0],
+             df.loc[idx]["caption_1"],
+             df.loc[idx]["caption_2"],
+             df.loc[idx]["caption_3"],
+             df.loc[idx]["caption_4"],
+             df.loc[idx]["caption_5"],
+         ]
+         writer.writerow(line)
+
+     f.close()
modeling/__init__.py ADDED
File without changes
modeling/enclap_bart.py ADDED
@@ -0,0 +1,548 @@
+ import math
+ import random
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     Seq2SeqLMOutput,
+     Seq2SeqModelOutput,
+ )
+ from transformers.models.bart.configuration_bart import BartConfig
+ from transformers.models.bart.modeling_bart import (
+     BartDecoder,
+     BartEncoderLayer,
+     BartForConditionalGeneration,
+     BartLearnedPositionalEmbedding,
+     BartModel,
+     BartPretrainedModel,
+     _expand_mask,
+     shift_tokens_right,
+ )
+ from transformers.utils import logging
+
+ from .modeling_outputs import EnClapBartOutput
+
+ logger = logging.get_logger(__name__)
+
+
+ class EnClapBartConfig(BartConfig):
+     def __init__(
+         self,
+         d_clap: int = 512,
+         num_rvq: int = 16,
+         encodec_vocab_size: int = 1024,
+         encodec_pad_token_id: int = 1024,
+         mcm_loss_scale: float = 0.7,
+         label_smoothing: float = 0.2,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.d_clap = d_clap
+         self.num_rvq = num_rvq
+         self.encodec_vocab_size = encodec_vocab_size
+         self.encodec_pad_token_id = encodec_pad_token_id
+         self.mcm_loss_scale = mcm_loss_scale
+         self.label_smoothing = label_smoothing
+
+
+ class EnClapBartEncoder(BartPretrainedModel):
+     """
+     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+     [`BartEncoderLayer`].
+
+     Args:
+         config: BartConfig
+         embed_tokens (nn.Embedding): output embedding
+     """
+
+     def __init__(
+         self, config: EnClapBartConfig, embed_tokens: Optional[nn.Embedding] = None
+     ):
+         super().__init__(config)
+
+         self.dropout = config.dropout
+         self.layerdrop = config.encoder_layerdrop
+
+         clap_dim = config.d_clap
+         embed_dim = config.d_model
+         self.padding_idx = config.pad_token_id
+         self.max_source_positions = config.max_position_embeddings
+         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
+         if embed_tokens is not None:
+             self.embed_tokens.weight = embed_tokens.weight
+
+         self.embed_encodec = nn.ModuleList(
+             [
+                 nn.Embedding(
+                     math.ceil((config.encodec_vocab_size + 1) / 64) * 64,
+                     config.d_model,
+                     padding_idx=config.encodec_pad_token_id,
+                 )
+                 for _ in range(config.num_rvq)
+             ]
+         )
+
+         self.clap_projection = nn.Linear(clap_dim, embed_dim)
+
+         self.embed_positions = BartLearnedPositionalEmbedding(
+             config.max_position_embeddings,
+             embed_dim,
+         )
+         self.layers = nn.ModuleList(
+             [BartEncoderLayer(config) for _ in range(config.encoder_layers)]
+         )
+         self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+         self.gradient_checkpointing = False
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         clap_embedding: Optional[torch.Tensor] = None,
+         encodec_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+
+                 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                 than the model's internal embedding lookup matrix.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time"
+             )
+         elif input_ids is not None:
+             if input_ids.ndim == 2:  # This is effectively just input = input_ids
+                 input = input_ids
+                 input_ids = input_ids.view(-1, input_ids.shape[-1])
+         elif inputs_embeds is not None:
+             input = inputs_embeds[:, :, -1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             if input_ids.ndim == 2:
+                 inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+             elif input_ids.ndim == 3:
+                 encodec_ids = torch.where(encodec_mask.unsqueeze(-1) > 0, input_ids, 0)
+                 encodec_embeds = torch.zeros(
+                     input_ids.shape[0], input_ids.shape[1], self.config.d_model
+                 ).to(self.device)
+                 for i, embed in enumerate(self.embed_encodec):
+                     encodec_embeds = encodec_embeds + embed(encodec_ids[..., i])
+                 bart_ids = torch.where(encodec_mask == 0, input_ids[..., 0], 0)
+                 bart_embeds = self.embed_tokens(bart_ids)
+                 input_embeds = torch.where(
+                     encodec_mask.unsqueeze(-1) > 0, encodec_embeds, bart_embeds
+                 )
+
+                 # Get CLAP embedding
+                 if clap_embedding is not None:
+                     clap_embedding = self.clap_projection(clap_embedding)
+                     input_embeds[:, 0] = clap_embedding
+                 inputs_embeds = input_embeds.to(self.device)
+
+         batch_size = input_ids.size(0)
+         embed_pos = self.embed_positions(input_ids).to(self.device)
+         embed_pos = torch.cat(
+             [
+                 torch.zeros(batch_size, 1, self.config.d_model).to(self.device),
+                 embed_pos[:, :-1],
+             ],
+             dim=1,
+         )
+
+         hidden_states = inputs_embeds + embed_pos
+         hidden_states = self.layernorm_embedding(hidden_states)
+         hidden_states = nn.functional.dropout(
+             hidden_states, p=self.dropout, training=self.training
+         )
+
+         # expand attention_mask
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+         encoder_states = () if output_hidden_states else None
+         all_attentions = () if output_attentions else None
+
+         # check if head_mask has a correct number of layers specified if desired
+         if head_mask is not None:
+             if head_mask.size()[0] != (len(self.layers)):
+                 raise ValueError(
+                     f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                     f" {head_mask.size()[0]}."
+                 )
+
+         for idx, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = random.uniform(0, 1)
+             if self.training and (
+                 dropout_probability < self.layerdrop
+             ):  # skip the layer
+                 layer_outputs = (None, None)
+             else:
+                 if self.gradient_checkpointing and self.training:
+
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             return module(*inputs, output_attentions)
+
+                         return custom_forward
+
+                     layer_outputs = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(encoder_layer),
+                         hidden_states,
+                         attention_mask,
+                         (head_mask[idx] if head_mask is not None else None),
+                     )
+                 else:
+                     layer_outputs = encoder_layer(
+                         hidden_states,
+                         attention_mask,
+                         layer_head_mask=(
+                             head_mask[idx] if head_mask is not None else None
+                         ),
+                         output_attentions=output_attentions,
+                     )
+
+                 hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_attentions = all_attentions + (layer_outputs[1],)
+
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, encoder_states, all_attentions]
+                 if v is not None
+             )
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_states,
+             attentions=all_attentions,
+         )
+
+
+ class EnClapBartModel(BartModel):
+     def __init__(self, config: EnClapBartConfig):
+         super(BartModel, self).__init__(config)
+
+         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+         self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+         self.encoder = EnClapBartEncoder(config, self.shared)
+         self.decoder = BartDecoder(config, self.shared)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         clap_embedding: Optional[torch.Tensor] = None,
+         encodec_mask: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Seq2SeqModelOutput]:
+         # different to other models, Bart automatically creates decoder_input_ids from
+         # input_ids if no decoder_input_ids are provided
+         if decoder_input_ids is None and decoder_inputs_embeds is None:
+             if input_ids is None:
+                 raise ValueError(
+                     "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
+                     "passed, `input_ids` cannot be `None`. Please pass either "
+                     "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
+                 )
+
+             decoder_input_ids = shift_tokens_right(
+                 input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
+             )
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 clap_embedding=clap_embedding,
+                 encodec_mask=encodec_mask,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+             encoder_outputs = BaseModelOutput(
+                 last_hidden_state=encoder_outputs[0],
+                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+             )
+
+         # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_outputs[0],
+             encoder_attention_mask=attention_mask,
+             head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if not return_dict:
+             return decoder_outputs + encoder_outputs
+
+         return Seq2SeqModelOutput(
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
402
+ decoder_attentions=decoder_outputs.attentions,
403
+ cross_attentions=decoder_outputs.cross_attentions,
404
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
405
+ encoder_hidden_states=encoder_outputs.hidden_states,
406
+ encoder_attentions=encoder_outputs.attentions,
407
+ )
408
+
409
+
410
+ class EnClapBartForConditionalGeneration(BartForConditionalGeneration):
411
+ config_class = EnClapBartConfig
412
+
413
+ def __init__(self, config: EnClapBartConfig):
414
+ super(BartForConditionalGeneration, self).__init__(config)
415
+ self.model = EnClapBartModel(config)
416
+ self.register_buffer(
417
+ "final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))
418
+ )
419
+ self.lm_head = nn.Linear(
420
+ config.d_model, self.model.shared.num_embeddings, bias=False
421
+ )
422
+ self.mcm_heads = nn.ModuleList(
423
+ [
424
+ nn.Linear(config.d_model, config.encodec_vocab_size)
425
+ for _ in range(config.num_rvq)
426
+ ]
427
+ )
428
+
429
+ # Initialize weights and apply final processing
430
+ self.post_init()
431
+
432
+ def forward(
433
+ self,
434
+ input_ids: torch.LongTensor = None,
435
+ clap_embedding: Optional[torch.Tensor] = None,
436
+ encodec_mask: Optional[torch.Tensor] = None,
437
+ attention_mask: Optional[torch.Tensor] = None,
438
+ decoder_input_ids: Optional[torch.LongTensor] = None,
439
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
440
+ head_mask: Optional[torch.Tensor] = None,
441
+ decoder_head_mask: Optional[torch.Tensor] = None,
442
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
443
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
444
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
445
+ inputs_embeds: Optional[torch.FloatTensor] = None,
446
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
447
+ labels: Optional[torch.LongTensor] = None,
448
+ mcm_labels: Optional[List[torch.LongTensor]] = None,
449
+ use_cache: Optional[bool] = None,
450
+ output_attentions: Optional[bool] = None,
451
+ output_hidden_states: Optional[bool] = None,
452
+ return_dict: Optional[bool] = None,
453
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
454
+ r"""
455
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
456
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
457
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
458
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
459
+
460
+ Returns:
461
+ """
462
+ return_dict = (
463
+ return_dict if return_dict is not None else self.config.use_return_dict
464
+ )
465
+
466
+ if labels is not None:
467
+ if use_cache:
468
+ logger.warning(
469
+ "The `use_cache` argument is changed to `False` since `labels` is provided."
470
+ )
471
+ use_cache = False
472
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
473
+ decoder_input_ids = shift_tokens_right(
474
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
475
+ )
476
+
477
+ outputs = self.model(
478
+ input_ids,
479
+ clap_embedding=clap_embedding,
480
+ encodec_mask=encodec_mask,
481
+ attention_mask=attention_mask,
482
+ decoder_input_ids=decoder_input_ids,
483
+ encoder_outputs=encoder_outputs,
484
+ decoder_attention_mask=decoder_attention_mask,
485
+ head_mask=head_mask,
486
+ decoder_head_mask=decoder_head_mask,
487
+ cross_attn_head_mask=cross_attn_head_mask,
488
+ past_key_values=past_key_values,
489
+ inputs_embeds=inputs_embeds,
490
+ decoder_inputs_embeds=decoder_inputs_embeds,
491
+ use_cache=use_cache,
492
+ output_attentions=output_attentions,
493
+ output_hidden_states=output_hidden_states,
494
+ return_dict=return_dict,
495
+ )
496
+
497
+ mcm_loss = None
498
+ if mcm_labels is not None:
499
+ mcm_loss = 0.0
500
+ loss_fct = CrossEntropyLoss()
501
+ for i, mcm_head in enumerate(self.mcm_heads):
502
+ mcm_logits = mcm_head(outputs.encoder_last_hidden_state)
503
+ loss_scale = 1 / 2 ** (i + 1)
504
+ loss = loss_fct(
505
+ mcm_logits.view(-1, self.config.encodec_vocab_size),
506
+ mcm_labels[..., i].reshape(-1),
507
+ )
508
+ mcm_loss = mcm_loss + loss * loss_scale
509
+
510
+ lm_logits = self.lm_head(outputs[0])
511
+ lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
512
+
513
+ masked_lm_loss = None
514
+ if labels is not None:
515
+ labels = labels.to(lm_logits.device)
516
+ loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
517
+ masked_lm_loss = loss_fct(
518
+ lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
519
+ )
520
+
521
+ loss = None
522
+ if mcm_loss is None:
523
+ loss = masked_lm_loss
524
+ elif masked_lm_loss is None:
525
+ loss = mcm_loss
526
+ else:
527
+ mcm_loss = mcm_loss * self.config.mcm_loss_scale
528
+ loss = masked_lm_loss + mcm_loss
529
+
530
+ if not return_dict:
531
+ output = (lm_logits,) + outputs[1:]
532
+ return (
533
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
534
+ )
535
+
536
+ return EnClapBartOutput(
537
+ loss=loss,
538
+ lm_loss=masked_lm_loss,
539
+ mcm_loss=mcm_loss,
540
+ logits=lm_logits,
541
+ past_key_values=outputs.past_key_values,
542
+ decoder_hidden_states=outputs.decoder_hidden_states,
543
+ decoder_attentions=outputs.decoder_attentions,
544
+ cross_attentions=outputs.cross_attentions,
545
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
546
+ encoder_hidden_states=outputs.encoder_hidden_states,
547
+ encoder_attentions=outputs.encoder_attentions,
548
+ )
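As the docstring above notes, label positions set to -100 are ignored by the captioning loss, and train.py below sets `label_pad_token_id = -100` whenever `ignore_pad_token_for_loss` is enabled. The snippet below is a minimal sketch of preparing such labels with the BART tokenizer; the checkpoint name, max length, and padding strategy are illustrative assumptions rather than the project's exact preprocessing.

# Sketch: build caption labels whose padded positions are skipped by CrossEntropyLoss.
# "facebook/bart-base" and max_length=50 are assumptions used only for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
captions = ["a dog barks in the distance", "rain falls on a tin roof"]
batch = tokenizer(
    text_target=captions, padding="max_length", max_length=50, return_tensors="pt"
)

labels = batch["input_ids"].clone()
labels[labels == tokenizer.pad_token_id] = -100  # ignored by the loss, as in train.py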
modeling/modeling_outputs.py ADDED
@@ -0,0 +1,11 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from transformers.modeling_outputs import Seq2SeqLMOutput
6
+
7
+
8
+ @dataclass
9
+ class EnClapBartOutput(Seq2SeqLMOutput):
10
+ mcm_loss: Optional[torch.FloatTensor] = None
11
+ lm_loss: Optional[torch.FloatTensor] = None
port_weights.py ADDED
@@ -0,0 +1,42 @@
1
+ import argparse
2
+ from math import ceil
3
+ from pathlib import Path
4
+
5
+ import torch
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("--ckpt_path", "-c", type=str)
10
+ args = parser.parse_args()
11
+
12
+ weight_name_map = {
13
+ "model.encodec_embeddings": None,
14
+ "encodec_embeddings": "embed_encodec",
15
+ "encodec_mlm_head": "mcm_heads",
16
+ }
17
+
18
+ ckpt_path = Path(args.ckpt_path)
19
+ weight_file = ckpt_path / "pytorch_model.bin"
20
+ state_dict = torch.load(weight_file, map_location="cpu")
21
+ new_state_dict = {}
22
+ for key in state_dict:
23
+ new_key = key
24
+ for orig, repl in weight_name_map.items():
25
+ if repl is None:
26
+ if orig in new_key:
27
+ new_key = None
28
+ break
29
+ continue
30
+ new_key = new_key.replace(orig, repl)
31
+ if new_key:
32
+ new_state_dict[new_key] = state_dict[key]
33
+ for key in new_state_dict:
34
+ if "model.encoder.embed_encodec" in key:
35
+ dim = new_state_dict[key].shape[0]
36
+ new_weight = torch.normal(
37
+ 0, 1, (ceil(dim / 64) * 64, new_state_dict[key].shape[1])
38
+ )
39
+ new_weight[:dim] = new_state_dict[key]
40
+ new_state_dict[key] = new_weight
41
+ weight_file.rename(weight_file.with_suffix(".bin.bak"))
42
+ torch.save(new_state_dict, weight_file)
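port_weights.py rewrites a legacy checkpoint in place: it drops the stale `model.encodec_embeddings.*` entries, renames `encodec_embeddings` to `embed_encodec` and `encodec_mlm_head` to `mcm_heads`, pads each `model.encoder.embed_encodec` weight with random rows up to a multiple of 64, and keeps the original file as `pytorch_model.bin.bak`. A minimal usage sketch follows; the checkpoint directory is a placeholder, and the asserts only illustrate the expected key layout after conversion.

# Run from the repo root; the directory must contain pytorch_model.bin:
#   python port_weights.py --ckpt_path /path/to/ckpt
# The converted state dict can then be sanity-checked like this:
import torch

state_dict = torch.load("/path/to/ckpt/pytorch_model.bin", map_location="cpu")
assert not any("encodec_mlm_head" in k for k in state_dict)  # renamed to mcm_heads
assert not any(k.startswith("model.encodec_embeddings") for k in state_dict)  # dropped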
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ aac-metrics==0.4.2
2
+ accelerate==0.20.3
3
+ datasets==2.13.1
4
+ encodec==0.1.1
5
+ laion-clap==1.1.4
6
+ librosa==0.10.1
7
+ markupsafe==2.0.1
8
+ omegaconf==2.3.0
9
+ soundfile==0.12.1
10
+ tensorboard==2.13.0
11
+ tokenizers==0.13.3
12
+ torch==1.13.0
13
+ torchaudio==0.13.0
14
+ transformers==4.29.0
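These pins are exact, so they are best installed into a clean environment with `pip install -r requirements.txt` (the notebooks under test/ were run on Python 3.9). The check below is an optional sketch for verifying that the pinned core libraries are the ones actually imported.

# Optional sanity check after `pip install -r requirements.txt`.
import accelerate
import torch
import transformers

assert torch.__version__.startswith("1.13"), torch.__version__
assert transformers.__version__ == "4.29.0", transformers.__version__
assert accelerate.__version__ == "0.20.3", accelerate.__version__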
test/bart_test.ipynb ADDED
@@ -0,0 +1,363 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/opt/conda/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from transformers import AutoTokenizer\n",
19
+ "from transformers.models.bart.modeling_bart import BartForConditionalGeneration"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large\")"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 35,
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "data": {
38
+ "text/plain": [
39
+ "('bart/tokenizer/tokenizer_config.json',\n",
40
+ " 'bart/tokenizer/special_tokens_map.json',\n",
41
+ " 'bart/tokenizer/vocab.json',\n",
42
+ " 'bart/tokenizer/merges.txt',\n",
43
+ " 'bart/tokenizer/added_tokens.json',\n",
44
+ " 'bart/tokenizer/tokenizer.json')"
45
+ ]
46
+ },
47
+ "execution_count": 35,
48
+ "metadata": {},
49
+ "output_type": "execute_result"
50
+ }
51
+ ],
52
+ "source": [
53
+ "tokenizer.save_pretrained(\"bart/tokenizer\")"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 18,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large\", forced_bos_token_id=0)"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": 4,
68
+ "metadata": {},
69
+ "outputs": [
70
+ {
71
+ "name": "stderr",
72
+ "output_type": "stream",
73
+ "text": [
74
+ "Some weights of AudioBartForConditionalGeneration were not initialized from the model checkpoint at bart/model/ and are newly initialized: ['model.encodec_embeddings.3.weight', 'model.encodec_embeddings.4.weight', 'model.encodec_embeddings.1.weight', 'model.encodec_embeddings.0.weight', 'model.encoder.encodec_embeddings.7.weight', 'model.encodec_embeddings.2.weight', 'model.encodec_embeddings.6.weight', 'model.encoder.encodec_embeddings.0.weight', 'model.encodec_embeddings.7.weight', 'model.encoder.encodec_embeddings.4.weight', 'model.encoder.encodec_embeddings.2.weight', 'model.encoder.encodec_embeddings.3.weight', 'model.encodec_embeddings.5.weight', 'model.encoder.encodec_embeddings.5.weight', 'model.encoder.encodec_embeddings.1.weight', 'model.encoder.encodec_embeddings.6.weight']\n",
75
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
76
+ ]
77
+ }
78
+ ],
79
+ "source": [
80
+ "from modeling.audiobart import AudioBartForConditionalGeneration\n",
81
+ "model = AudioBartForConditionalGeneration.from_pretrained(\"bart/model/\")"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 5,
87
+ "metadata": {},
88
+ "outputs": [
89
+ {
90
+ "name": "stdout",
91
+ "output_type": "stream",
92
+ "text": [
93
+ "{'input_ids': tensor([[ 0, 31414, 127, 50264, 32440, 3807, 118, 32440, 3807, 118,\n",
94
+ " 25610, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n"
95
+ ]
96
+ }
97
+ ],
98
+ "source": [
99
+ "text = \"Hello my <mask> yeppi yeppi yo\"\n",
100
+ "input = tokenizer(text, return_tensors='pt')\n",
101
+ "print(input)"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 8,
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "generated_ids = model.generate(input[\"input_ids\"])"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 33,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "ids = output.logits.detach().numpy().argmax(-1)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 11,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "data": {
129
+ "text/plain": [
130
+ "['Hello my friends, yeppi yeppiiyeppiyeppii ye']"
131
+ ]
132
+ },
133
+ "execution_count": 11,
134
+ "metadata": {},
135
+ "output_type": "execute_result"
136
+ }
137
+ ],
138
+ "source": [
139
+ "tokenizer.batch_decode(generated_ids, skip_special_tokens=True)"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 36,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "model.save_pretrained(\"bart/model\")"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": 51,
154
+ "metadata": {},
155
+ "outputs": [
156
+ {
157
+ "name": "stdout",
158
+ "output_type": "stream",
159
+ "text": [
160
+ "{'input_ids': tensor([[ 0, 7842, 330, 506, 1536, 267, 131, 6634, 36807, 571,\n",
161
+ " 20920, 127, 766, 16, 32440, 3807, 118, 32440, 3807, 118,\n",
162
+ " 25610, 2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "text = \"adskfalsj;lsdfg Hello my name is yeppi yeppi yo\"\n",
168
+ "input = tokenizer(text, return_tensors='pt')\n",
169
+ "print(input)"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 52,
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "output = model.forward(**input)"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 54,
184
+ "metadata": {},
185
+ "outputs": [
186
+ {
187
+ "name": "stderr",
188
+ "output_type": "stream",
189
+ "text": [
190
+ "/opt/conda/lib/python3.9/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (20) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
191
+ " warnings.warn(\n"
192
+ ]
193
+ },
194
+ {
195
+ "data": {
196
+ "text/plain": [
197
+ "['</s><s>adskfalsj;lsdfg Hello my name is yeppi ye</s>']"
198
+ ]
199
+ },
200
+ "execution_count": 54,
201
+ "metadata": {},
202
+ "output_type": "execute_result"
203
+ }
204
+ ],
205
+ "source": [
206
+ "tokenizer.batch_decode(model.generate(input['input_ids']))"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": 39,
212
+ "metadata": {},
213
+ "outputs": [
214
+ {
215
+ "data": {
216
+ "text/plain": [
217
+ "['<s>Hello my name is yeppi yeppi yo</s>']"
218
+ ]
219
+ },
220
+ "execution_count": 39,
221
+ "metadata": {},
222
+ "output_type": "execute_result"
223
+ }
224
+ ],
225
+ "source": [
226
+ "tokenizer.batch_decode(input['input_ids'])"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 45,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "from transformers.models.bart.modeling_bart import shift_tokens_right"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 48,
241
+ "metadata": {},
242
+ "outputs": [
243
+ {
244
+ "name": "stdout",
245
+ "output_type": "stream",
246
+ "text": [
247
+ "tensor([[ 0, 0, 31414, 127, 766, 16, 32440, 3807, 118, 32440,\n",
248
+ " 3807, 118, 25610]])\n"
249
+ ]
250
+ }
251
+ ],
252
+ "source": [
253
+ "print(shift_tokens_right(input['input_ids'], pad_token_id=1, decoder_start_token_id=0))"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": 55,
259
+ "metadata": {},
260
+ "outputs": [
261
+ {
262
+ "name": "stderr",
263
+ "output_type": "stream",
264
+ "text": [
265
+ "Downloading (…)lve/main/config.json: 100%|██████████| 1.58k/1.58k [00:00<00:00, 589kB/s]\n",
266
+ "Downloading (…)olve/main/vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 1.29MB/s]\n",
267
+ "Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 884kB/s]\n",
268
+ "Downloading (…)/main/tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 7.43MB/s]\n"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "cnn_tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large-cnn\")"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": 63,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "original_text = \"ArithmeticErrorThe tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.\""
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": 68,
288
+ "metadata": {},
289
+ "outputs": [
290
+ {
291
+ "name": "stdout",
292
+ "output_type": "stream",
293
+ "text": [
294
+ "['<s>ArithmeticErrorThe tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.</s>']\n"
295
+ ]
296
+ }
297
+ ],
298
+ "source": [
299
+ "input = cnn_tokenizer(text=original_text, return_tensors='pt')\n",
300
+ "print(cnn_tokenizer.batch_decode(input['input_ids']))"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": 65,
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "cnn_model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 66,
315
+ "metadata": {},
316
+ "outputs": [
317
+ {
318
+ "name": "stderr",
319
+ "output_type": "stream",
320
+ "text": [
321
+ "/opt/conda/lib/python3.9/site-packages/transformers/generation/utils.py:1353: UserWarning: Using `max_length`'s default (142) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
322
+ " warnings.warn(\n"
323
+ ]
324
+ },
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "['</s><s>The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world. It was the first structure to reach a height of 300 metres.</s>']"
329
+ ]
330
+ },
331
+ "execution_count": 66,
332
+ "metadata": {},
333
+ "output_type": "execute_result"
334
+ }
335
+ ],
336
+ "source": [
337
+ "cnn_tokenizer.batch_decode(cnn_model.generate(input['input_ids']))"
338
+ ]
339
+ }
340
+ ],
341
+ "metadata": {
342
+ "kernelspec": {
343
+ "display_name": "base",
344
+ "language": "python",
345
+ "name": "python3"
346
+ },
347
+ "language_info": {
348
+ "codemirror_mode": {
349
+ "name": "ipython",
350
+ "version": 3
351
+ },
352
+ "file_extension": ".py",
353
+ "mimetype": "text/x-python",
354
+ "name": "python",
355
+ "nbconvert_exporter": "python",
356
+ "pygments_lexer": "ipython3",
357
+ "version": "3.9.12"
358
+ },
359
+ "orig_nbformat": 4
360
+ },
361
+ "nbformat": 4,
362
+ "nbformat_minor": 2
363
+ }
test/clap_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
test/dataset_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
test/dataset_test.py ADDED
@@ -0,0 +1,57 @@
1
+ import sys
2
+ sys.path.append(".")
3
+ sys.path.append("..")
4
+
5
+ from datasets import load_dataset
6
+ from transformers import AutoTokenizer
7
+ from modeling.audiobart import AudioBartForConditionalGeneration
8
+ from torch.utils.data import DataLoader
9
+ from data.collator import EncodecCollator
10
+
11
+ import numpy as np
12
+ import torch
13
+ import os
14
+
15
+ if __name__=="__main__":
16
+ model = AudioBartForConditionalGeneration.from_pretrained('bart/model')
17
+ base_path = "/data/jyk/aac_dataset/AudioCaps/encodec_16/"
18
+ tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
19
+ data_files = {"train": "csv/AudioCaps/train.csv"}
20
+ max_encodec_length = 1021
21
+ clap_base_path = "/data/jyk/aac_dataset/AudioCaps/clap"
22
+
23
+ raw_dataset = load_dataset("csv", data_files=data_files)
24
+
25
+ def preprocess_function(example):
26
+ path = example['file_path']
27
+ encodec = np.load(os.path.join(base_path, path))
28
+ if encodec.shape[0]>max_encodec_length:
29
+ encodec = encodec[:max_encodec_length, :]
30
+ clap = np.load(os.path.join(clap_base_path, path))
31
+ attention_mask = np.ones(encodec.shape[0]+3).astype(np.int64)
32
+ target_text = tokenizer(text_target=example['caption'])
33
+
34
+ return {'input_ids': encodec, 'clap': clap, 'attention_mask': attention_mask, 'labels': target_text['input_ids'], 'decoder_attention_mask': target_text['attention_mask']}
35
+
36
+ train_dataset = raw_dataset['train'].map(preprocess_function)
37
+ train_dataset.set_format("pt", columns=['input_ids', 'attention_mask', 'clap', 'labels', 'decoder_attention_mask'])
38
+
39
+ train_data_collator = EncodecCollator(
40
+ tokenizer=tokenizer,
41
+ model=model,
42
+ return_tensors="pt",
43
+ random_sampling=False,
44
+ max_length=max_encodec_length,
45
+ num_subsampling=0,
46
+ clap_masking_prob=-1,
47
+ encodec_masking_prob=0.15,
48
+ encodec_masking_length=10
49
+ )
50
+
51
+ train_dataloader = DataLoader(
52
+ train_dataset, shuffle=True, collate_fn=train_data_collator, batch_size=16)
53
+
54
+ for idx, batch in enumerate(train_dataloader):
55
+ # output = model.generate(**batch, max_length=100)
56
+ output = model(**batch)
57
+ print(output)
test/encodec_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
test/encodec_test.py ADDED
@@ -0,0 +1,24 @@
1
+ from encodec import EncodecModel
2
+ from encodec.utils import convert_audio
3
+
4
+ import torchaudio
5
+ import torch
6
+
7
+ # Instantiate a pretrained EnCodec model
8
+ model = EncodecModel.encodec_model_24khz()
9
+ # The number of codebooks used will be determined by the bandwidth selected.
10
+ # E.g. for a bandwidth of 6kbps, `n_q = 8` codebooks are used.
11
+ # Supported bandwidths are 1.5 kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8), 12 kbps (n_q = 16), and 24 kbps (n_q = 32).
12
+ # For the 48 kHz model, only 3, 6, 12, and 24 kbps are supported. The number
13
+ # of codebooks for each is half that of the 24 kHz model, as the frame rate is twice as high.
14
+ model.set_target_bandwidth(6.0)
15
+
16
+ # Load and pre-process the audio waveform
17
+ wav, sr = torchaudio.load("<PATH_TO_AUDIO_FILE>")
18
+ wav = convert_audio(wav, sr, model.sample_rate, model.channels)
19
+ wav = wav.unsqueeze(0)
20
+
21
+ # Extract discrete codes from EnCodec
22
+ with torch.no_grad():
23
+ encoded_frames = model.encode(wav)
24
+ codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
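The script above ends with `codes` of shape `[B, n_q, T]`, whereas the dataset code in this repo (test/dataset_test.py, train.py) loads one `.npy` array per clip and slices it as `[T, n_q]` via `encodec[:max_encodec_length, :]`. The lines below sketch the export step this implies; the output path is a placeholder, and the actual extraction script is not part of this commit.

# Continues from the script above: store the codes as a [T, n_q] numpy array,
# matching what np.load(...) returns in the dataset scripts. Path is a placeholder.
import numpy as np

codes_np = codes.squeeze(0).transpose(0, 1).cpu().numpy()  # [1, n_q, T] -> [T, n_q]
np.save("<PATH_TO_OUTPUT_NPY>", codes_np)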
test/eval_dataset_test.py ADDED
@@ -0,0 +1,42 @@
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer
3
+ from modeling.audiobart import AudioBartForConditionalGeneration
4
+ from torch.utils.data import DataLoader
5
+ from data.collator import EncodecCollator
6
+
7
+ import numpy as np
8
+ import torch
9
+ import os
10
+
11
+ if __name__=="__main__":
12
+ model = AudioBartForConditionalGeneration.from_pretrained('bart/model')
13
+ basepath = "/data/jyk/aac_dataset/clotho/encodec/"
14
+ tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
15
+ data_files = {"validation": "csv/valid_allcaps.csv"}
16
+ num_captions = 5
17
+
18
+ raw_dataset = load_dataset("csv", data_files=data_files)
19
+
20
+ def preprocess_eval(example):
21
+ path = example['file_path']
22
+ encodec = np.load(os.path.join(basepath, path))
23
+ if encodec.shape[0]>1022:
24
+ encodec = encodec[:1022, :]
25
+ attention_mask = np.ones(encodec.shape[0]+2).astype(np.int64)
26
+ captions = []
27
+ for i in range(1, num_captions+1):
28
+ captions.append(example['caption_'+str(i)])
29
+
30
+ return {'input_ids': encodec, 'attention_mask': attention_mask, 'captions': captions}
31
+
32
+ train_dataset = raw_dataset['validation'].map(preprocess_eval)
33
+ train_dataset.set_format('pt', columns=['input_ids', 'attention_mask'], output_all_columns=True)
34
+ # train_dataset.remove_columns('file_path', 'caption_1', 'caption_2', 'caption_3', 'caption_4', 'caption_5')
35
+ data_collator = EncodecCollator(tokenizer=tokenizer, model=model, return_tensors="pt")
36
+
37
+ train_dataloader = DataLoader(
38
+ train_dataset, shuffle=True, collate_fn=data_collator, batch_size=16)
39
+
40
+ for idx, batch in enumerate(train_dataloader):
41
+ output = model.generate(**batch, max_length=100)
42
+ print(output)
test/masking_test.ipynb ADDED
@@ -0,0 +1,117 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/opt/conda/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "import torch"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "probability_matrix = torch.full((8, 15), 0.15)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 4,
33
+ "metadata": {},
34
+ "outputs": [
35
+ {
36
+ "data": {
37
+ "text/plain": [
38
+ "torch.Size([8, 15])"
39
+ ]
40
+ },
41
+ "execution_count": 4,
42
+ "metadata": {},
43
+ "output_type": "execute_result"
44
+ }
45
+ ],
46
+ "source": [
47
+ "probability_matrix.shape"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 9,
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "data": {
57
+ "text/plain": [
58
+ "tensor([[0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
59
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
60
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
61
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
62
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
63
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
64
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
65
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
66
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
67
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
68
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
69
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
70
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
71
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500],\n",
72
+ " [0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500,\n",
73
+ " 0.1500, 0.1500, 0.1500, 0.1500, 0.1500, 0.1500]])"
74
+ ]
75
+ },
76
+ "execution_count": 9,
77
+ "metadata": {},
78
+ "output_type": "execute_result"
79
+ }
80
+ ],
81
+ "source": [
82
+ "probability_matrix.masked_fill_(torch.tensor(0, dtype=torch.bool), value=0.0)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "masked_indices = torch.bernoulli()"
92
+ ]
93
+ }
94
+ ],
95
+ "metadata": {
96
+ "kernelspec": {
97
+ "display_name": "base",
98
+ "language": "python",
99
+ "name": "python3"
100
+ },
101
+ "language_info": {
102
+ "codemirror_mode": {
103
+ "name": "ipython",
104
+ "version": 3
105
+ },
106
+ "file_extension": ".py",
107
+ "mimetype": "text/x-python",
108
+ "name": "python",
109
+ "nbconvert_exporter": "python",
110
+ "pygments_lexer": "ipython3",
111
+ "version": "3.9.12"
112
+ },
113
+ "orig_nbformat": 4
114
+ },
115
+ "nbformat": 4,
116
+ "nbformat_minor": 2
117
+ }
test/metric_test.ipynb ADDED
@@ -0,0 +1,260 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 8,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import evaluate"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 9,
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stderr",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "[nltk_data] Downloading package wordnet to /root/nltk_data...\n",
22
+ "[nltk_data] Package wordnet is already up-to-date!\n",
23
+ "[nltk_data] Downloading package punkt to /root/nltk_data...\n",
24
+ "[nltk_data] Package punkt is already up-to-date!\n",
25
+ "[nltk_data] Downloading package omw-1.4 to /root/nltk_data...\n",
26
+ "[nltk_data] Package omw-1.4 is already up-to-date!\n"
27
+ ]
28
+ }
29
+ ],
30
+ "source": [
31
+ "metric = evaluate.load(\"meteor\")"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 6,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "import pandas as pd\n",
41
+ "df = pd.read_csv(\"csv/predictions.csv\")"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 7,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "predictions = []\n",
51
+ "references = []\n",
52
+ "for idx in range(len(df)):\n",
53
+ " predictions.append(df.loc[idx]['prediction'])\n",
54
+ " reference = [df.loc[idx]['caption1'],df.loc[idx]['caption2'],df.loc[idx]['caption3'],df.loc[idx]['caption4'],df.loc[idx]['caption5'] ]\n",
55
+ " references.append(reference)"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 8,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "from aac_metrics import evaluate\n",
65
+ "corpus_score = evaluate(predictions, references)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 4,
71
+ "metadata": {},
72
+ "outputs": [
73
+ {
74
+ "data": {
75
+ "text/plain": [
76
+ "{'bleu_1': tensor(0.3913, dtype=torch.float64),\n",
77
+ " 'bleu_2': tensor(0.1931, dtype=torch.float64),\n",
78
+ " 'bleu_3': tensor(0.1065, dtype=torch.float64),\n",
79
+ " 'bleu_4': tensor(0.0569, dtype=torch.float64),\n",
80
+ " 'meteor': tensor(0.1197, dtype=torch.float64),\n",
81
+ " 'rouge_l': tensor(0.2745, dtype=torch.float64),\n",
82
+ " 'cider_d': tensor(0.1235, dtype=torch.float64),\n",
83
+ " 'spice': tensor(0.0670, dtype=torch.float64),\n",
84
+ " 'spider': tensor(0.0953, dtype=torch.float64)}"
85
+ ]
86
+ },
87
+ "execution_count": 4,
88
+ "metadata": {},
89
+ "output_type": "execute_result"
90
+ }
91
+ ],
92
+ "source": [
93
+ "corpus_score[0]"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 8,
99
+ "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "name": "stdout",
103
+ "output_type": "stream",
104
+ "text": [
105
+ "{'bleu_1': 0.3912776883574468, 'bleu_2': 0.19312066269135236, 'bleu_3': 0.10651188216812753, 'bleu_4': 0.05690269475018141, 'meteor': 0.11968742992878356, 'rouge_l': 0.2744644068893943, 'cider_d': 0.12347016800968286, 'spice': 0.06704068138550699, 'spider': 0.09525542469759493}\n"
106
+ ]
107
+ }
108
+ ],
109
+ "source": [
110
+ "print({k: v.item() for k, v in corpus_score[0].items()})"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 13,
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "results = metric.compute(predictions=predictions, references=references)"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 14,
125
+ "metadata": {},
126
+ "outputs": [
127
+ {
128
+ "name": "stdout",
129
+ "output_type": "stream",
130
+ "text": [
131
+ "{'meteor': 0.26686702985116983}\n"
132
+ ]
133
+ }
134
+ ],
135
+ "source": [
136
+ "print(results)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 8,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "bleu = evaluate.load(\"bleu\")"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 9,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "from transformers import AutoTokenizer\n",
155
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-large\")"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": 11,
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "bleu_result = bleu.compute(predictions=predictions, references=references, max_order=4)"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": 12,
170
+ "metadata": {},
171
+ "outputs": [
172
+ {
173
+ "name": "stdout",
174
+ "output_type": "stream",
175
+ "text": [
176
+ "{'bleu': 0.06128958043343902, 'precisions': [0.42544588056899413, 0.09036238675413934, 0.031210136916404455, 0.01176031360836289], 'brevity_penalty': 1.0, 'length_ratio': 1.3508583690987124, 'translation_length': 13849, 'reference_length': 10252}\n"
177
+ ]
178
+ }
179
+ ],
180
+ "source": [
181
+ "print(bleu_result)"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 5,
187
+ "metadata": {},
188
+ "outputs": [
189
+ {
190
+ "ename": "AttributeError",
191
+ "evalue": "'function' object has no attribute 'load'",
192
+ "output_type": "error",
193
+ "traceback": [
194
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
195
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
196
+ "\u001b[1;32m/workspace/audiobart/metric_test.ipynb Cell 13\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell://attached-container%2B7b22636f6e7461696e65724e616d65223a222f617564696f62617274222c2273657474696e6773223a7b22686f7374223a227373683a2f2f3138332e3131302e36322e3639227d7d/workspace/audiobart/metric_test.ipynb#X14sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m rouge_metric \u001b[39m=\u001b[39m evaluate\u001b[39m.\u001b[39;49mload(\u001b[39m\"\u001b[39m\u001b[39mrouge\u001b[39m\u001b[39m\"\u001b[39m)\n",
197
+ "\u001b[0;31mAttributeError\u001b[0m: 'function' object has no attribute 'load'"
198
+ ]
199
+ }
200
+ ],
201
+ "source": [
202
+ "rouge_metric = evaluate.load(\"rouge\")"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 14,
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "rouge_result = rouge_metric.compute(predictions=predictions, references=references)"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 15,
217
+ "metadata": {},
218
+ "outputs": [
219
+ {
220
+ "data": {
221
+ "text/plain": [
222
+ "{'rouge1': 0.30500784605917763,\n",
223
+ " 'rouge2': 0.08778194034686765,\n",
224
+ " 'rougeL': 0.2707178803695874,\n",
225
+ " 'rougeLsum': 0.27045227295118685}"
226
+ ]
227
+ },
228
+ "execution_count": 15,
229
+ "metadata": {},
230
+ "output_type": "execute_result"
231
+ }
232
+ ],
233
+ "source": [
234
+ "rouge_result"
235
+ ]
236
+ }
237
+ ],
238
+ "metadata": {
239
+ "kernelspec": {
240
+ "display_name": "base",
241
+ "language": "python",
242
+ "name": "python3"
243
+ },
244
+ "language_info": {
245
+ "codemirror_mode": {
246
+ "name": "ipython",
247
+ "version": 3
248
+ },
249
+ "file_extension": ".py",
250
+ "mimetype": "text/x-python",
251
+ "name": "python",
252
+ "nbconvert_exporter": "python",
253
+ "pygments_lexer": "ipython3",
254
+ "version": "3.9.12"
255
+ },
256
+ "orig_nbformat": 4
257
+ },
258
+ "nbformat": 4,
259
+ "nbformat_minor": 2
260
+ }
test/subsample_test.ipynb ADDED
@@ -0,0 +1,121 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import sys\n",
10
+ "sys.path.append(\"..\")"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 3,
16
+ "metadata": {},
17
+ "outputs": [
18
+ {
19
+ "name": "stderr",
20
+ "output_type": "stream",
21
+ "text": [
22
+ "/opt/conda/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
23
+ " from .autonotebook import tqdm as notebook_tqdm\n"
24
+ ]
25
+ }
26
+ ],
27
+ "source": [
28
+ "from modeling.audiobart import Subsampler"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 4,
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "subsampler = Subsampler(1024, 3, 2)"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": 5,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "from utils import count_parameters"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 6,
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "data": {
56
+ "text/plain": [
57
+ "1053696"
58
+ ]
59
+ },
60
+ "execution_count": 6,
61
+ "metadata": {},
62
+ "output_type": "execute_result"
63
+ }
64
+ ],
65
+ "source": [
66
+ "count_parameters(subsampler)"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 7,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "import torch"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 25,
81
+ "metadata": {},
82
+ "outputs": [
83
+ {
84
+ "name": "stdout",
85
+ "output_type": "stream",
86
+ "text": [
87
+ "torch.Size([8, 1023, 1024])\n"
88
+ ]
89
+ }
90
+ ],
91
+ "source": [
92
+ "input = torch.randn(8, 4095, 1024)\n",
93
+ "output = subsampler(input)\n",
94
+ "output = subsampler(output)\n",
95
+ "print(output.shape)"
96
+ ]
97
+ }
98
+ ],
99
+ "metadata": {
100
+ "kernelspec": {
101
+ "display_name": "base",
102
+ "language": "python",
103
+ "name": "python3"
104
+ },
105
+ "language_info": {
106
+ "codemirror_mode": {
107
+ "name": "ipython",
108
+ "version": 3
109
+ },
110
+ "file_extension": ".py",
111
+ "mimetype": "text/x-python",
112
+ "name": "python",
113
+ "nbconvert_exporter": "python",
114
+ "pygments_lexer": "ipython3",
115
+ "version": "3.9.12"
116
+ },
117
+ "orig_nbformat": 4
118
+ },
119
+ "nbformat": 4,
120
+ "nbformat_minor": 2
121
+ }
test/train_test.ipynb ADDED
@@ -0,0 +1,261 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from datasets import load_dataset\n",
10
+ "from transformers import AutoTokenizer"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 3,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "import os\n",
20
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"4,5,6,7\""
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 4,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "basepath = \"/data/jyk/aac_dataset/clotho/encodec/\"\n",
30
+ "tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 5,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "data_files = {\"train\": \"csv/train_short.csv\", \"validation\": \"csv/valid_short.csv\"}"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 6,
45
+ "metadata": {},
46
+ "outputs": [
47
+ {
48
+ "name": "stderr",
49
+ "output_type": "stream",
50
+ "text": [
51
+ "Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-8533483370f473b7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)\n",
52
+ "100%|██████████| 2/2 [00:00<00:00, 923.96it/s]\n"
53
+ ]
54
+ }
55
+ ],
56
+ "source": [
57
+ "raw_dataset = load_dataset(\"csv\", data_files=data_files)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 7,
63
+ "metadata": {},
64
+ "outputs": [
65
+ {
66
+ "data": {
67
+ "text/plain": [
68
+ "Dataset({\n",
69
+ " features: ['file_path', 'caption'],\n",
70
+ " num_rows: 19175\n",
71
+ "})"
72
+ ]
73
+ },
74
+ "execution_count": 7,
75
+ "metadata": {},
76
+ "output_type": "execute_result"
77
+ }
78
+ ],
79
+ "source": [
80
+ "raw_dataset['train']"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 10,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "from data.collator import EncodecCollator\n",
90
+ "import numpy as np\n",
91
+ "import os"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 11,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "def preprocessing(example):\n",
101
+ " path = example['file_path']\n",
102
+ " encodec = np.load(os.path.join(basepath, path))\n",
103
+ " if encodec.shape[0]>1022:\n",
104
+ " encodec = encodec[:1022, :]\n",
105
+ " attention_mask = np.ones(encodec.shape[0]+2)\n",
106
+ " target_text = tokenizer(text_target=example['caption'])\n",
107
+ "\n",
108
+ " return {'input_ids': encodec , 'attention_mask': attention_mask, 'labels': target_text['input_ids'], 'decoder_attention_mask': target_text['attention_mask']}\n"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 12,
114
+ "metadata": {},
115
+ "outputs": [
116
+ {
117
+ "name": "stderr",
118
+ "output_type": "stream",
119
+ "text": [
120
+ " \r"
121
+ ]
122
+ }
123
+ ],
124
+ "source": [
125
+ "train_dataset = raw_dataset['train'].map(preprocessing, num_proc=16)"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 13,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "train_dataset.set_format(\"np\", columns=['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 32,
140
+ "metadata": {},
141
+ "outputs": [
142
+ {
143
+ "name": "stderr",
144
+ "output_type": "stream",
145
+ "text": [
146
+ "Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-8533483370f473b7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-a3db71731640afd3_*_of_00016.arrow\n"
147
+ ]
148
+ }
149
+ ],
150
+ "source": [
151
+ "valid_dataset = raw_dataset['validation'].map(preprocessing, num_proc=16)\n",
152
+ "valid_dataset.set_format(\"np\", columns=['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 16,
158
+ "metadata": {},
159
+ "outputs": [
160
+ {
161
+ "name": "stderr",
162
+ "output_type": "stream",
163
+ "text": [
164
+ "Some weights of AudioBartForConditionalGeneration were not initialized from the model checkpoint at bart/model and are newly initialized: ['model.encodec_embeddings.6.weight', 'model.encodec_embeddings.4.weight', 'model.encodec_embeddings.1.weight', 'model.encodec_embeddings.7.weight', 'model.encodec_embeddings.5.weight', 'model.encodec_embeddings.3.weight', 'model.encodec_embeddings.2.weight', 'model.encodec_embeddings.0.weight']\n",
165
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
166
+ ]
167
+ }
168
+ ],
169
+ "source": [
170
+ "from modeling.audiobart import AudioBartForConditionalGeneration\n",
171
+ "model = AudioBartForConditionalGeneration.from_pretrained('bart/model')"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 25,
177
+ "metadata": {},
178
+ "outputs": [
179
+ {
180
+ "name": "stdout",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "414688256\n"
184
+ ]
185
+ }
186
+ ],
187
+ "source": [
188
+ "from utils import count_parameters\n",
189
+ "print(count_parameters(model))"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 17,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "data_collator = EncodecCollator(tokenizer=tokenizer, model=model, return_tensors=\"pt\")"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 19,
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 36,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "training_args = Seq2SeqTrainingArguments('summary_test', per_gpu_train_batch_size=16)"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": 37,
222
+ "metadata": {},
223
+ "outputs": [
224
+ {
225
+ "name": "stderr",
226
+ "output_type": "stream",
227
+ "text": [
228
+ "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.\n"
229
+ ]
230
+ }
231
+ ],
232
+ "source": [
233
+ "trainer = Seq2SeqTrainer(\n",
234
+ " model, training_args, train_dataset=valid_dataset, eval_dataset=valid_dataset, data_collator=data_collator, tokenizer=tokenizer\n",
235
+ ")"
236
+ ]
237
+ }
238
+ ],
239
+ "metadata": {
240
+ "kernelspec": {
241
+ "display_name": "base",
242
+ "language": "python",
243
+ "name": "python3"
244
+ },
245
+ "language_info": {
246
+ "codemirror_mode": {
247
+ "name": "ipython",
248
+ "version": 3
249
+ },
250
+ "file_extension": ".py",
251
+ "mimetype": "text/x-python",
252
+ "name": "python",
253
+ "nbconvert_exporter": "python",
254
+ "pygments_lexer": "ipython3",
255
+ "version": "3.9.12"
256
+ },
257
+ "orig_nbformat": 4
258
+ },
259
+ "nbformat": 4,
260
+ "nbformat_minor": 2
261
+ }
test/train_test.py ADDED
@@ -0,0 +1,41 @@
1
+ from datasets import load_dataset
2
+ from transformers import AutoTokenizer
3
+ from modeling.audiobart import AudioBartForConditionalGeneration
4
+ from data.collator import EncodecCollator
5
+ from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
6
+
7
+ import numpy as np
8
+ import torch
9
+ import os
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"
11
+
12
+ if __name__=="__main__":
13
+ model = AudioBartForConditionalGeneration.from_pretrained('bart/model')
14
+ basepath = "/data/jyk/aac_dataset/clotho/encodec/"
15
+ tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large')
16
+ data_files = {"train": "csv/train_short.csv", "validation": "csv/valid_short.csv"}
17
+
18
+ raw_dataset = load_dataset("csv", data_files=data_files)
19
+
20
+ def preprocessing(example):
21
+ path = example['file_path']
22
+ encodec = np.load(os.path.join(basepath, path))
23
+ if encodec.shape[0]>1022:
24
+ encodec = encodec[:1022, :]
25
+ attention_mask = np.ones(encodec.shape[0]+2)
26
+ target_text = tokenizer(text_target=example['caption'])
27
+
28
+ return {'input_ids': encodec , 'attention_mask': attention_mask, 'labels': target_text['input_ids'], 'decoder_attention_mask': target_text['attention_mask']}
29
+
30
+ train_dataset = raw_dataset['validation'].map(preprocessing)
31
+ train_dataset.set_format("pt", columns=['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'])
32
+
33
+ data_collator = EncodecCollator(tokenizer=tokenizer, model=model, return_tensors="pt")
34
+
35
+ training_args = Seq2SeqTrainingArguments('summary_test', per_gpu_train_batch_size=20)
36
+
37
+ trainer = Seq2SeqTrainer(
38
+ model, training_args, train_dataset=train_dataset, eval_dataset=train_dataset, data_collator=data_collator, tokenizer=tokenizer
39
+ )
40
+
41
+ trainer.train()
train.py ADDED
@@ -0,0 +1,527 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import logging
17
+ import math
18
+ import os
19
+ import sys
20
+
21
+ import datasets
22
+ import numpy as np
23
+ import torch
24
+ import transformers
25
+ from aac_metrics import evaluate
26
+ from accelerate import Accelerator, DistributedDataParallelKwargs
27
+ from accelerate.logging import get_logger
28
+ from accelerate.utils import set_seed
29
+ from datasets import load_dataset
30
+ from omegaconf import OmegaConf
31
+ from torch.utils.data import DataLoader
32
+ from tqdm.auto import tqdm
33
+ from transformers import (
34
+ AutoTokenizer,
35
+ BartConfig,
36
+ get_inverse_sqrt_schedule,
37
+ get_scheduler,
38
+ )
39
+
40
+ from data.collator import DataCollatorForEnClapBart
41
+ from data.preprocess import Preprocessor
42
+ from modeling.enclap_bart import EnClapBartForConditionalGeneration
43
+
44
+ logger = get_logger(__name__)
45
+ metric_list = ["meteor", "spider"]
46
+
47
+
48
+ def main():
49
+ # Load Configuration
50
+ cfg_path = sys.argv[1]
51
+ args = OmegaConf.load(cfg_path)
52
+
53
+ # Initialize Logging
54
+ accelerator_log_kwargs = {}
55
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
56
+ if args.with_tracking:
57
+ accelerator_log_kwargs["log_with"] = args.report_to
58
+ accelerator_log_kwargs["project_dir"] = args.output_dir
59
+
60
+ # Initialize Accelerator
61
+ accelerator = Accelerator(
62
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
63
+ split_batches=args.split_batches,
64
+ kwargs_handlers=[ddp_kwargs],
65
+ **accelerator_log_kwargs,
66
+ )
67
+ # Handle the repository creation
68
+ if accelerator.is_main_process:
69
+ if args.output_dir is not None:
70
+ os.makedirs(args.output_dir, exist_ok=True)
71
+ with open(os.path.join(args.output_dir, "args.yaml"), "w") as f:
72
+ OmegaConf.save(args, f)
73
+ accelerator.wait_for_everyone()
74
+
75
+ # Make one log on every process with the configuration for debugging.
76
+ logging.basicConfig(
77
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
78
+ datefmt="%m/%d/%Y %H:%M:%S",
79
+ level=logging.INFO,
80
+ )
81
+ file_handler = logging.FileHandler(os.path.join(args.output_dir, "train_log.txt"))
82
+ logger.logger.addHandler(file_handler)
83
+ logger.info(accelerator.state, main_process_only=False)
84
+ if accelerator.is_local_main_process:
85
+ datasets.utils.logging.set_verbosity_warning()
86
+ transformers.utils.logging.set_verbosity_warning()
87
+ else:
88
+ datasets.utils.logging.set_verbosity_error()
89
+ transformers.utils.logging.set_verbosity_error()
90
+
91
+ # If passed along, set the training seed now.
92
+ if args.seed is not None:
93
+ set_seed(args.seed)
94
+
95
+ # Get the datasets
96
+ data_files = {}
97
+ data_files_eval = {}
98
+ if args.train_file is not None:
99
+ data_files["train"] = args.train_file
100
+ if args.validation_file is not None:
101
+ data_files_eval["validation"] = args.validation_file
102
+
103
+ extension = args.train_file.split(".")[-1]
104
+ raw_datasets = load_dataset(extension, data_files=data_files)
105
+ raw_datasets_eval = load_dataset(extension, data_files=data_files_eval)
106
+
107
+ # Load pretrained model and tokenizer
108
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
109
+ if args.config_name_or_path is not None:
110
+ config = BartConfig.from_pretrained(args.config_name_or_path)
111
+ else:
112
+ config = None
113
+
114
+ if args.model_name_or_path is not None:
115
+ if config is None:
116
+ model = EnClapBartForConditionalGeneration.from_pretrained(
117
+ args.model_name_or_path
118
+ )
119
+ else:
120
+ model = EnClapBartForConditionalGeneration.from_pretrained(
121
+ args.model_name_or_path, config=config
122
+ )
123
+ else:
124
+ model = EnClapBartForConditionalGeneration(config=config)
125
+
126
+ # Set the generation config
127
+ if args.val_max_target_length is None:
128
+ args.val_max_target_length = args.max_target_length
129
+
130
+ # Set max encodec length based on the shape of the positional encoding
131
+ max_encodec_length = model.config.max_position_embeddings - 2
132
+ label_pad_token_id = (
133
+ -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
134
+ )
135
+ preprocessor = Preprocessor(
136
+ args.encodec_base_path,
137
+ args.clap_base_path,
138
+ tokenizer,
139
+ model.config.max_position_embeddings,
140
+ args.encodec_masking_prob,
141
+ args.encodec_masking_span,
142
+ label_pad_token_id,
143
+ model.config.encodec_vocab_size,
144
+ args.eval_num_captions,
145
+ )
146
+
147
+ with accelerator.main_process_first():
148
+ train_dataset = raw_datasets["train"].map(
149
+ preprocessor.preprocess_train,
150
+ num_proc=args.preprocessing_num_workers,
151
+ load_from_cache_file=not args.overwrite_cache,
152
+ desc="Running tokenizer on dataset",
153
+ )
154
+ train_dataset.set_format(
155
+ "pt",
156
+ columns=[
157
+ "input_ids",
158
+ "attention_mask",
159
+ "clap",
160
+ "labels",
161
+ "decoder_attention_mask",
162
+ ],
163
+ )
164
+
165
+ # Temporarily set max_target_length for validation.
166
+ eval_dataset = raw_datasets_eval["validation"].map(
167
+ preprocessor.preprocess_eval,
168
+ num_proc=args.preprocessing_num_workers,
169
+ load_from_cache_file=not args.overwrite_cache,
170
+ desc="Running tokenizer on dataset",
171
+ )
172
+ eval_dataset.set_format(
173
+ "pt",
174
+ columns=["input_ids", "attention_mask", "clap"],
175
+ output_all_columns=True,
176
+ )
177
+
178
+ train_data_collator = DataCollatorForEnClapBart(
179
+ tokenizer=tokenizer,
180
+ model=model,
181
+ return_tensors="pt",
182
+ label_pad_token_id=label_pad_token_id,
183
+ max_length=max_encodec_length,
184
+ encodec_masking_prob=args.encodec_masking_prob,
185
+ encodec_masking_span=args.encodec_masking_span,
186
+ )
187
+ valid_data_collator = DataCollatorForEnClapBart(
188
+ tokenizer=tokenizer,
189
+ model=model,
190
+ return_tensors="pt",
191
+ label_pad_token_id=label_pad_token_id,
192
+ max_length=max_encodec_length,
193
+ )
194
+
195
+ train_dataloader = DataLoader(
196
+ train_dataset,
197
+ shuffle=True,
198
+ collate_fn=train_data_collator,
199
+ batch_size=args.per_device_train_batch_size,
200
+ )
201
+ eval_dataloader = DataLoader(
202
+ eval_dataset,
203
+ collate_fn=valid_data_collator,
204
+ batch_size=args.per_device_eval_batch_size,
205
+ )
206
+
207
+ # Optimizer
208
+ # Split weights in two groups, one with weight decay and the other not.
209
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
210
+ optimizer_grouped_parameters = [
211
+ {
212
+ "params": [
213
+ p
214
+ for n, p in model.named_parameters()
215
+ if not any(nd in n for nd in no_decay)
216
+ ],
217
+ "weight_decay": args.weight_decay,
218
+ },
219
+ {
220
+ "params": [
221
+ p
222
+ for n, p in model.named_parameters()
223
+ if any(nd in n for nd in no_decay)
224
+ ],
225
+ "weight_decay": 0.0,
226
+ },
227
+ ]
228
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
229
+
230
+ # Scheduler and math around the number of training steps.
231
+ overrode_max_train_steps = False
232
+ num_update_steps_per_epoch = math.ceil(
233
+ len(train_dataloader) / args.gradient_accumulation_steps
234
+ )
235
+ if args.max_train_steps is None:
236
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
237
+ overrode_max_train_steps = True
238
+
239
+ if args.lr_scheduler_type == "inverse_sqrt" and hasattr(args, "time_scale"):
240
+ lr_scheduler = get_inverse_sqrt_schedule(
241
+ optimizer=optimizer,
242
+ num_warmup_steps=args.num_warmup_steps,
243
+ timescale=args.time_scale,
244
+ )
245
+ else:
246
+ lr_scheduler = get_scheduler(
247
+ name=args.lr_scheduler_type,
248
+ optimizer=optimizer,
249
+ num_warmup_steps=args.num_warmup_steps,
250
+ num_training_steps=args.max_train_steps,
251
+ )
252
+
253
+ # Prepare everything with our `accelerator`.
254
+ (
255
+ model,
256
+ optimizer,
257
+ train_dataloader,
258
+ eval_dataloader,
259
+ lr_scheduler,
260
+ ) = accelerator.prepare(
261
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
262
+ )
263
+
264
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
265
+ num_update_steps_per_epoch = math.ceil(
266
+ len(train_dataloader) / args.gradient_accumulation_steps
267
+ )
268
+ if overrode_max_train_steps:
269
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
270
+ # Afterwards we recalculate our number of training epochs
271
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
272
+
273
+ # Figure out how many steps we should save the Accelerator states
274
+ checkpointing_steps = args.checkpointing_steps
275
+ if checkpointing_steps is not None and checkpointing_steps.isdigit():
276
+ checkpointing_steps = int(checkpointing_steps)
277
+
278
+ # The trackers initialize automatically on the main process.
279
+ if args.with_tracking:
280
+ accelerator.init_trackers(args.logging_dir)
281
+
282
+ # Train!
283
+ total_batch_size = (
284
+ args.per_device_train_batch_size
285
+ * accelerator.num_processes
286
+ * args.gradient_accumulation_steps
287
+ )
288
+
289
+ if args.split_batches:
290
+ total_batch_size = int(total_batch_size / accelerator.num_processes)
291
+
292
+ logger.info("***** Running training *****")
293
+ logger.info(f" Num examples = {len(train_dataset)}")
294
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
295
+ logger.info(
296
+ f" Instantaneous batch size per device = {args.per_device_train_batch_size}"
297
+ )
298
+ logger.info(
299
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
300
+ )
301
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
302
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
303
+
304
+ completed_steps = 0
305
+ starting_epoch = 0
306
+ # Potentially load in the weights and states from a previous save
307
+ if not args.overwrite_output_dir and os.path.exists(
308
+ os.path.join(args.output_dir, "checkpoints")
309
+ ):
310
+ if args.resume_from_checkpoint is not None:
311
+ accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
312
+ accelerator.load_state(args.resume_from_checkpoint)
313
+ path = os.path.basename(args.resume_from_checkpoint)
314
+ else:
315
+ # Get the most recent checkpoint
316
+ dirs = [
317
+ f
318
+ for f in os.scandir(os.path.join(args.output_dir, "checkpoints"))
319
+ if f.is_dir()
320
+ ]
321
+ dirs.sort(key=os.path.getctime)
322
+ path = dirs[
323
+ -1
324
+ ].name # Sorts folders by date modified, most recent checkpoint is the last
325
+ accelerator.print(f"Resumed from checkpoint: {dirs[-1]}")
326
+ accelerator.load_state(dirs[-1])
327
+ # Extract `epoch_{i}` or `step_{i}`
328
+ training_difference = os.path.splitext(path)[0]
329
+
330
+ if "epoch" in training_difference:
331
+ starting_epoch = int(training_difference.replace("epoch_", "")) + 1
332
+ resume_step = None
333
+ completed_steps = starting_epoch * num_update_steps_per_epoch
334
+ else:
335
+ # need to multiply `gradient_accumulation_steps` to reflect real steps
336
+ resume_step = (
337
+ int(training_difference.replace("step_", ""))
338
+ * args.gradient_accumulation_steps
339
+ )
340
+ starting_epoch = resume_step // len(train_dataloader)
341
+ resume_step -= starting_epoch * len(train_dataloader)
342
+ completed_steps = resume_step // args.gradient_accumulation_steps
343
+
344
+ # Initialize the running loss trackers (also needed when resuming from a checkpoint)
345
+ if args.with_tracking:
346
+ total_loss = 0
347
+ logging_loss = 0
348
+ before_epoch_loss = 0
349
+
350
+ if args.encodec_masking_prob > 0:
351
+ total_encodec_loss = 0
352
+ logging_encodec_loss = 0
353
+ before_epoch_encodec_loss = 0
354
+
355
+ for epoch in range(starting_epoch, args.num_train_epochs):
356
+ model.train()
357
+ if (
358
+ args.resume_from_checkpoint
359
+ and epoch == starting_epoch
360
+ and resume_step is not None
361
+ ):
362
+ # We skip the first `n` batches in the dataloader when resuming from a checkpoint
363
+ active_dataloader = accelerator.skip_first_batches(
364
+ train_dataloader, resume_step
365
+ )
366
+ else:
367
+ active_dataloader = train_dataloader
368
+ logger.info(f"***** Running epoch {epoch} *****")
369
+ epoch_iterator = tqdm(
370
+ active_dataloader,
371
+ desc="Training",
372
+ disable=not accelerator.is_local_main_process,
373
+ dynamic_ncols=True,
374
+ colour="CYAN",
375
+ )
376
+ for step, batch in enumerate(epoch_iterator):
377
+ with accelerator.accumulate(model):
378
+ outputs = model(**batch)
379
+ loss = outputs.loss
380
+ # We keep track of the loss at each epoch
381
+ if args.with_tracking:
382
+ total_loss += outputs.lm_loss.item()
383
+ if args.encodec_masking_prob > 0:
384
+ if outputs.encodec_loss is not None:
385
+ total_encodec_loss += outputs.encodec_loss.item()
386
+ accelerator.backward(loss)
387
+ if accelerator.sync_gradients:
388
+ accelerator.clip_grad_norm_(
389
+ model.parameters(), max_norm=args.max_grad_norm
390
+ )
391
+ optimizer.step()
392
+ lr_scheduler.step()
393
+ optimizer.zero_grad()
394
+
395
+ # Checks if the accelerator has performed an optimization step behind the scenes
396
+ if accelerator.sync_gradients:
397
+ completed_steps += 1
398
+ # Add loss information to tqdm
399
+ epoch_iterator.set_postfix(loss=total_loss / completed_steps)
400
+
401
+ if completed_steps % args.logging_steps == 0:
402
+ train_log = {
403
+ "train/learning_rate": lr_scheduler.get_last_lr()[0]
404
+ }
405
+ train_log["train/loss"] = (
406
+ total_loss - logging_loss
407
+ ) / args.logging_steps
408
+ logging_loss = total_loss
409
+ if args.encodec_masking_prob > 0:
410
+ train_log["train/encodec_loss"] = (
411
+ total_encodec_loss - logging_encodec_loss
412
+ ) / args.logging_steps
413
+ logging_encodec_loss = total_encodec_loss
414
+ accelerator.log(train_log, step=completed_steps)
415
+
416
+ if isinstance(checkpointing_steps, int):
417
+ if completed_steps % checkpointing_steps == 0:
418
+ output_dir = f"step_{completed_steps}"
419
+ if args.output_dir is not None:
420
+ output_dir = os.path.join(
421
+ args.output_dir, "checkpoints", output_dir
422
+ )
423
+ accelerator.save_state(output_dir)
424
+
425
+ if completed_steps >= args.max_train_steps:
426
+ break
427
+
428
+ model.eval()
429
+ gen_kwargs = {
430
+ "max_length": args.val_max_target_length,
431
+ }
432
+ predictions = []
433
+ references = []
434
+ eval_iterator = tqdm(
435
+ eval_dataloader,
436
+ desc="Validation",
437
+ disable=not accelerator.is_local_main_process,
438
+ dynamic_ncols=True,
439
+ colour="MAGENTA",
440
+ )
441
+ for step, batch in enumerate(eval_iterator):
442
+ # Drop the padded samples of the last batch of dataloader
443
+ # try:
444
+ # if accelerator.gradient_state.end_of_dataloader and accelerator.gradient_state.remainder > 0:
445
+ # batch = batch[:accelerator.gradient_state.remainder]
446
+ # except:
447
+ # pass
448
+
449
+ with torch.no_grad():
450
+ batch["input_ids"] = batch["input_ids"].cuda()
451
+ batch["clap"] = batch["clap"].cuda()
452
+ batch["attention_mask"] = batch["attention_mask"].cuda()
453
+ batch["eos_mask"] = batch["eos_mask"].cuda()
454
+
455
+ generated_tokens = accelerator.unwrap_model(model).generate(
456
+ batch["input_ids"],
457
+ clap=batch["clap"],
458
+ attention_mask=batch["attention_mask"],
459
+ eos_mask=batch["eos_mask"],
460
+ **gen_kwargs,
461
+ )
462
+
463
+ generated_tokens = accelerator.pad_across_processes(
464
+ generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
465
+ )
466
+ generated_tokens = generated_tokens.cpu().numpy()
467
+ captions = batch["captions"]
468
+
469
+ if isinstance(generated_tokens, tuple):
470
+ generated_tokens = generated_tokens[0]
471
+ decoded_preds = tokenizer.batch_decode(
472
+ generated_tokens, skip_special_tokens=True
473
+ )
474
+
475
+ predictions.extend(decoded_preds)
476
+ references.extend(captions)
477
+
478
+ logger.info("Evaluating predictions...")
479
+ result = evaluate(predictions, references, metrics=metric_list)
480
+
481
+ # Gather Result
482
+ result = {k: v.cuda() for k, v in result[0].items()}
483
+ result = accelerator.gather_for_metrics(result)
484
+ # Log the average of metrics among the processes
485
+ if accelerator.num_processes > 1:
486
+ result = {f"eval/{k}": round(v.mean().item(), 4) for k, v in result.items()}
487
+ else:
488
+ result = {f"eval/{k}": round(v.item(), 4) for k, v in result.items()}
489
+ logger.info(result)
490
+
491
+ if args.with_tracking:
492
+ result["train/epoch_train_loss"] = (total_loss - before_epoch_loss) / len(
493
+ train_dataloader
494
+ )
495
+ result["train/steps"] = completed_steps
496
+ before_epoch_loss = total_loss
497
+ if args.encodec_masking_prob > 0:
498
+ result["train/epoch_encodec_loss"] = (
499
+ total_encodec_loss - before_epoch_encodec_loss
500
+ ) / len(train_dataloader)
501
+ before_epoch_encodec_loss = total_encodec_loss
502
+ accelerator.log(result, step=epoch)
503
+
504
+ if args.checkpointing_steps == "epoch":
505
+ output_dir = f"epoch_{epoch}"
506
+ if args.output_dir is not None:
507
+ output_dir = os.path.join(args.output_dir, "checkpoints", output_dir)
508
+ accelerator.save_state(output_dir)
509
+ if accelerator.is_main_process:
510
+ unwrapped_model = accelerator.unwrap_model(model)
511
+ unwrapped_model.config.save_pretrained(output_dir)
512
+
513
+ if args.output_dir is not None:
514
+ save_dir = os.path.join(args.output_dir, "final")
515
+ accelerator.wait_for_everyone()
516
+ unwrapped_model = accelerator.unwrap_model(model)
517
+ unwrapped_model.save_pretrained(
518
+ save_dir,
519
+ is_main_process=accelerator.is_main_process,
520
+ save_function=accelerator.save,
521
+ )
522
+ if accelerator.is_main_process:
523
+ tokenizer.save_pretrained(save_dir)
524
+
525
+
526
+ if __name__ == "__main__":
527
+ main()
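Once training finishes, the script saves the final model and tokenizer under `<output_dir>/final`. The following is a minimal, untested inference sketch built only from the calls visible in the evaluation loop above; the feature file names, tensor shapes, and `eos_mask` construction are assumptions, since the exact preprocessing lives in `data/preprocess.py`:

```python
# Hedged inference sketch; paths, shapes and eos_mask handling are assumptions.
import torch
from transformers import AutoTokenizer

from modeling.enclap_bart import EnClapBartForConditionalGeneration

ckpt_dir = "/output/final"  # args.output_dir + "/final" with the configs shown earlier
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = EnClapBartForConditionalGeneration.from_pretrained(ckpt_dir).eval()

input_ids = torch.load("example_encodec.pt")  # EnCodec token ids, batch dim first (assumed)
clap = torch.load("example_clap.pt")          # CLAP embedding for the same clip (assumed)
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.long)
eos_mask = torch.zeros_like(attention_mask)
eos_mask[:, -1] = 1  # assumed: flag the last encoder position

with torch.no_grad():
    generated = model.generate(
        input_ids,
        clap=clap,
        attention_mask=attention_mask,
        eos_mask=eos_mask,
        max_length=50,  # matches val_max_target_length in the configs
    )
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```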
train.sh ADDED
@@ -0,0 +1,2 @@
1
+ CFG_PATH="cfg/clotho/base.yaml"
2
+ accelerate launch --multi_gpu --main_process_port=1200 train.py $CFG_PATH
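As a usage note: `train.sh` launches multi-GPU training with the Clotho config; pointing `CFG_PATH` at `cfg/audiocaps/base.yaml` or `cfg/audiocaps/large.yaml` switches to the AudioCaps setups shown earlier. For a single GPU, dropping the `--multi_gpu` flag (e.g. `accelerate launch train.py cfg/audiocaps/base.yaml`) should be enough, though that invocation is an assumption rather than something exercised in this commit.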