diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..533264f69d41465d4f322ed6e7fadac8ee60a16d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+.venv
+.env
+.cache
+__pycache__
+data/audio/*.wav
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..b54f9f40e9105b3d97c2dce2dbb66759d293af12
--- /dev/null
+++ b/app.py
@@ -0,0 +1,66 @@
+import streamlit as st
+from streamlit import session_state as session
+from src.config.configs import ProjectPaths
+import numpy as np
+from src.laion_clap.inference import AudioEncoder
+
+
+@st.cache_data(persist=True, show_spinner=False)
+def load_data():
+ vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))
+ return vectors
+
+
+recommender = AudioEncoder()
+audio_vectors = load_data()
+
+dataframe = None
+
+st.title("""
+Curate me a Playlist.
+ """)
+
+st.text("")
+st.text("")
+st.text("")
+st.text("")
+
+session.text_input = st.text_input(label="Describe a playlist")
+
+st.text("")
+st.text("")
+
+session.slider_count = st.slider(label="Number of tracks", min_value=5, max_value=50)
+
+st.text("")
+st.text("")
+
+buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
+
+is_clicked = col1.button(label="Curate")
+
+if is_clicked:
+ text_embed = recommender.get_text_embedding(session.text_input)
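+    # Hedged sketch of one way to fill `dataframe`: rank the precomputed audio vectors
+    # against the text embedding, mirroring notebooks/notebook.ipynb. Assumes the rows of
+    # audio_representations.npy follow the glob order of data/audio/*.wav.
+    import pandas as pd
+    from glob import glob
+
+    music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav")))
+    song_names = [p.split("/")[-1] for p in music_files]
+    scores = np.ravel(audio_vectors @ np.asarray(text_embed).T)
+    dataframe = (
+        pd.DataFrame({"Track": song_names, "Similarity": scores})
+        .nlargest(session.slider_count, "Similarity")
+        .reset_index(drop=True)
+    )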
+
+
+st.text("")
+st.text("")
+st.text("")
+st.text("")
+
+if dataframe is not None:
+ st.table(dataframe)
\ No newline at end of file
diff --git a/data/.DS_Store b/data/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..add7523c63fc1770d3b377991b1f9c2bb3e70725
Binary files /dev/null and b/data/.DS_Store differ
diff --git a/data/audio/.gitkeep b/data/audio/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/json/saved_tracks.json b/data/json/saved_tracks.json
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/vectors/audio_representations.npy b/data/vectors/audio_representations.npy
new file mode 100644
index 0000000000000000000000000000000000000000..84544edd82ab9a414b723384bc1e89ab995038e7
--- /dev/null
+++ b/data/vectors/audio_representations.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe4a3ff8cfd2a6b13407352868f3f74fb290ebc11e8473e7132dd4bf947108da
+size 1290368
diff --git a/model_checkpoints/.gitkeep b/model_checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model_checkpoints/music_audioset_epoch_15_esc_90.14.pt b/model_checkpoints/music_audioset_epoch_15_esc_90.14.pt
new file mode 100644
index 0000000000000000000000000000000000000000..09274ba1b6f219de82e4265777848c7a41747e9e
--- /dev/null
+++ b/model_checkpoints/music_audioset_epoch_15_esc_90.14.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fae3e9c087f2909c28a09dc31c8dfcdacbc42ba44c70e972b58c1bd1caf6dedd
+size 2352471003
diff --git a/notebooks/notebook.ipynb b/notebooks/notebook.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ae577ada192cbc7241641753552e272c2cf98f27
--- /dev/null
+++ b/notebooks/notebook.ipynb
@@ -0,0 +1,701 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import librosa\n",
+ "import torch\n",
+ "from src import laion_clap\n",
+ "from glob import glob\n",
+ "import pandas as pd\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']\n",
+ "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Load the specified checkpoint music_audioset_epoch_15_esc_90.14.pt from users.\n",
+ "Load Checkpoint...\n",
+ "logit_scale_a \t Loaded\n",
+ "logit_scale_t \t Loaded\n",
+ "audio_branch.spectrogram_extractor.stft.conv_real.weight \t Loaded\n",
+ "audio_branch.spectrogram_extractor.stft.conv_imag.weight \t Loaded\n",
+ "audio_branch.logmel_extractor.melW \t Loaded\n",
+ "audio_branch.bn0.weight \t Loaded\n",
+ "audio_branch.bn0.bias \t Loaded\n",
+ "audio_branch.patch_embed.proj.weight \t Loaded\n",
+ "audio_branch.patch_embed.proj.bias \t Loaded\n",
+ "audio_branch.patch_embed.norm.weight \t Loaded\n",
+ "audio_branch.patch_embed.norm.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.norm1.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.norm1.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.norm2.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.norm2.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.0.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.norm1.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.norm1.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.norm2.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.norm2.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.0.blocks.1.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.0.downsample.reduction.weight \t Loaded\n",
+ "audio_branch.layers.0.downsample.norm.weight \t Loaded\n",
+ "audio_branch.layers.0.downsample.norm.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.norm1.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.norm1.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.norm2.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.norm2.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.0.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.norm1.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.norm1.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.norm2.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.norm2.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.1.blocks.1.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.1.downsample.reduction.weight \t Loaded\n",
+ "audio_branch.layers.1.downsample.norm.weight \t Loaded\n",
+ "audio_branch.layers.1.downsample.norm.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.0.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.1.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.2.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.3.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.4.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.5.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.6.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.7.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.8.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.9.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.10.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.norm1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.norm1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.norm2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.norm2.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.2.blocks.11.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.2.downsample.reduction.weight \t Loaded\n",
+ "audio_branch.layers.2.downsample.norm.weight \t Loaded\n",
+ "audio_branch.layers.2.downsample.norm.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.norm1.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.norm1.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.norm2.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.norm2.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.0.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.norm1.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.norm1.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.attn.relative_position_bias_table \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.attn.qkv.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.attn.qkv.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.attn.proj.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.attn.proj.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.norm2.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.norm2.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.mlp.fc1.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.mlp.fc1.bias \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.mlp.fc2.weight \t Loaded\n",
+ "audio_branch.layers.3.blocks.1.mlp.fc2.bias \t Loaded\n",
+ "audio_branch.norm.weight \t Loaded\n",
+ "audio_branch.norm.bias \t Loaded\n",
+ "audio_branch.tscam_conv.weight \t Loaded\n",
+ "audio_branch.tscam_conv.bias \t Loaded\n",
+ "audio_branch.head.weight \t Loaded\n",
+ "audio_branch.head.bias \t Loaded\n",
+ "text_branch.embeddings.word_embeddings.weight \t Loaded\n",
+ "text_branch.embeddings.position_embeddings.weight \t Loaded\n",
+ "text_branch.embeddings.token_type_embeddings.weight \t Loaded\n",
+ "text_branch.embeddings.LayerNorm.weight \t Loaded\n",
+ "text_branch.embeddings.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.0.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.0.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.1.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.1.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.2.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.2.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.3.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.3.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.4.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.4.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.5.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.5.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.6.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.6.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.7.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.7.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.8.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.8.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.9.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.9.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.10.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.10.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.self.query.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.self.query.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.self.key.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.self.key.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.self.value.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.self.value.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.attention.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.intermediate.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.intermediate.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.output.dense.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.output.dense.bias \t Loaded\n",
+ "text_branch.encoder.layer.11.output.LayerNorm.weight \t Loaded\n",
+ "text_branch.encoder.layer.11.output.LayerNorm.bias \t Loaded\n",
+ "text_branch.pooler.dense.weight \t Loaded\n",
+ "text_branch.pooler.dense.bias \t Loaded\n",
+ "text_transform.sequential.0.weight \t Loaded\n",
+ "text_transform.sequential.0.bias \t Loaded\n",
+ "text_transform.sequential.3.weight \t Loaded\n",
+ "text_transform.sequential.3.bias \t Loaded\n",
+ "text_projection.0.weight \t Loaded\n",
+ "text_projection.0.bias \t Loaded\n",
+ "text_projection.2.weight \t Loaded\n",
+ "text_projection.2.bias \t Loaded\n",
+ "audio_transform.sequential.0.weight \t Loaded\n",
+ "audio_transform.sequential.0.bias \t Loaded\n",
+ "audio_transform.sequential.3.weight \t Loaded\n",
+ "audio_transform.sequential.3.bias \t Loaded\n",
+ "audio_projection.0.weight \t Loaded\n",
+ "audio_projection.0.bias \t Loaded\n",
+ "audio_projection.2.weight \t Loaded\n",
+ "audio_projection.2.bias \t Loaded\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base')\n",
+ "model.load_ckpt(ckpt=\"music_audioset_epoch_15_esc_90.14.pt\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def load_music_file(file_name):\n",
+ " audio_data, _ = librosa.load(file_name, sr=48000) # sample rate should be 48000\n",
+ " audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T)\n",
+ " # audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # quantize before send it in to the model\n",
+ " with torch.no_grad():\n",
+ " audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False)\n",
+ " return audio_embed\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "music_files = glob(\"/Users/berkayg/Codes/music-project/AudioCLIP/data/downloaded_tracks/*.wav\")[:100]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/sr/r72219hj06x_1xvw7hhd517h0000gn/T/ipykernel_18860/3009710654.py:2: UserWarning: PySoundFile failed. Trying audioread instead.\n",
+ " audio_data, _ = librosa.load(file_name, sr=48000) # sample rate should be 48000\n",
+ "/Users/berkayg/miniforge3/envs/playlist-curator/lib/python3.10/site-packages/librosa/core/audio.py:183: FutureWarning: librosa.core.audio.__audioread_load\n",
+ "\tDeprecated as of librosa version 0.10.0.\n",
+ "\tIt will be removed in librosa version 1.0.\n",
+ " y, sr_native = __audioread_load(path, offset, duration, dtype)\n"
+ ]
+ }
+ ],
+ "source": [
+ "music_data = np.zeros((len(music_files), 512), dtype=np.float32)\n",
+ "for m in range(music_data.shape[0]):\n",
+ " music_data[m] = load_music_file(music_files[m])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(1, 512)\n"
+ ]
+ }
+ ],
+ "source": [
+ "text_data = [\"This audio is a romantic song\"] \n",
+ "text_embed = model.get_text_embedding(text_data)\n",
+ "print(text_embed.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "song_names = [k.split(\"/\")[-1] for k in music_files]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([100, 1])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " ranking = torch.tensor(music_data) @ torch.tensor(text_embed).t()\n",
+ " ranking = ranking[:, 0].reshape(-1, 1)\n",
+ "print(ranking.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " This audio is a romantic song | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Coldplay - Charlie Brown.wav | \n",
+ " 0.400684 | \n",
+ "
\n",
+ " \n",
+ " Sam Smith - I'm Not The Only One.wav | \n",
+ " 0.373561 | \n",
+ "
\n",
+ " \n",
+ " Pink Floyd - The Great Gig In The Sky - 2011 Remastered Version.wav | \n",
+ " 0.371584 | \n",
+ "
\n",
+ " \n",
+ " Christina Aguilera - You Lost Me.wav | \n",
+ " 0.370390 | \n",
+ "
\n",
+ " \n",
+ " Lana Del Rey - Yayo.wav | \n",
+ " 0.370379 | \n",
+ "
\n",
+ " \n",
+ " Queen - It's A Hard Life - Remastered 2011.wav | \n",
+ " 0.348699 | \n",
+ "
\n",
+ " \n",
+ " Teoman - Haziran.wav | \n",
+ " 0.331220 | \n",
+ "
\n",
+ " \n",
+ " John Lennon - Imagine - Remastered 2010.wav | \n",
+ " 0.330397 | \n",
+ "
\n",
+ " \n",
+ " Sleeping At Last - Mars.wav | \n",
+ " 0.328770 | \n",
+ "
\n",
+ " \n",
+ " Adele - Someone Like You.wav | \n",
+ " 0.325650 | \n",
+ "
\n",
+ " \n",
+ " Coldplay - What If.wav | \n",
+ " 0.315717 | \n",
+ "
\n",
+ " \n",
+ " Adamlar - Orda Ortada.wav | \n",
+ " 0.306465 | \n",
+ "
\n",
+ " \n",
+ " Eric Clapton - Autumn Leaves.wav | \n",
+ " 0.305451 | \n",
+ "
\n",
+ " \n",
+ " Premiata Forneria Marconi - Impressioni di settembre.wav | \n",
+ " 0.295878 | \n",
+ "
\n",
+ " \n",
+ " Guthrie Govan - Lost in Rio.wav | \n",
+ " 0.284883 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " This audio is a romantic song\n",
+ "Coldplay - Charlie Brown.wav 0.400684\n",
+ "Sam Smith - I'm Not The Only One.wav 0.373561\n",
+ "Pink Floyd - The Great Gig In The Sky - 2011 Re... 0.371584\n",
+ "Christina Aguilera - You Lost Me.wav 0.370390\n",
+ "Lana Del Rey - Yayo.wav 0.370379\n",
+ "Queen - It's A Hard Life - Remastered 2011.wav 0.348699\n",
+ "Teoman - Haziran.wav 0.331220\n",
+ "John Lennon - Imagine - Remastered 2010.wav 0.330397\n",
+ "Sleeping At Last - Mars.wav 0.328770\n",
+ "Adele - Someone Like You.wav 0.325650\n",
+ "Coldplay - What If.wav 0.315717\n",
+ "Adamlar - Orda Ortada.wav 0.306465\n",
+ "Eric Clapton - Autumn Leaves.wav 0.305451\n",
+ "Premiata Forneria Marconi - Impressioni di sett... 0.295878\n",
+ "Guthrie Govan - Lost in Rio.wav 0.284883"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame(ranking, columns=[text_data[0]], index=song_names).nlargest(15, text_data[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "playlist-curator",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/orchestrate_audio_data.py b/orchestrate_audio_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..6299850291979f9347623673dc703d5966218090
--- /dev/null
+++ b/orchestrate_audio_data.py
@@ -0,0 +1,8 @@
+from src.data.spotify import list_personal_saved_tracks
+from src.data.get_yt_links import collect_youtube_links
+from src.data.pytuber import start_download_process
+
+if __name__ == "__main__":
+ list_personal_saved_tracks()
+ collect_youtube_links()
+ start_download_process()
diff --git a/recommender.py b/recommender.py
new file mode 100644
index 0000000000000000000000000000000000000000..357d8f96a63c2ffcda29249df6eefa1db4cb8911
--- /dev/null
+++ b/recommender.py
@@ -0,0 +1,19 @@
+from src.laion_clap.inference import AudioEncoder
+from src.config.configs import ProjectPaths
+from glob import glob
+
+recommender = AudioEncoder()
+# audio = recommender.extract_bulk_audio_representaions(save=False)
+result = recommender.get_text_embedding("This audio is a romantic song")
+music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav")))
+song_names = [k.split("/")[-1] for k in music_files]
+print(result)
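+# Hedged sketch of one way to use the embedding: score the saved audio vectors against
+# the text query, mirroring notebooks/notebook.ipynb. Assumes the vector rows follow the
+# glob order of the audio files above.
+import numpy as np
+import pandas as pd
+
+vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"))
+scores = np.ravel(vectors @ np.asarray(result).T)
+print(pd.DataFrame({"track": song_names, "similarity": scores}).nlargest(10, "similarity"))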
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..76d8dbd8c27b89b4466374154b4b88b2d7d7c114
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,89 @@
+altair==5.1.2
+anyio==4.0.0
+appdirs==1.4.4
+async-timeout==4.0.3
+attrs==23.1.0
+audioread==3.0.1
+blinker==1.7.0
+braceexpand==0.1.7
+cachetools==5.3.2
+certifi==2023.7.22
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+docker-pycreds==0.4.0
+filelock==3.13.1
+fsspec==2023.10.0
+ftfy==6.1.1
+gitdb==4.0.11
+GitPython==3.1.40
+google-api-python-client==2.105.0
+google-auth-httplib2==0.1.1
+h11==0.14.0
+h5py==3.10.0
+httpcore==1.0.2
+httplib2==0.22.0
+httpx==0.25.1
+huggingface-hub==0.19.4
+idna==3.4
+Jinja2==3.1.2
+joblib==1.3.2
+jsonschema==4.20.0
+jsonschema-specifications==2023.11.1
+lazy_loader==0.3
+librosa==0.10.1
+llvmlite==0.41.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+mdurl==0.1.2
+msgpack==1.0.7
+numba==0.58.1
+numpy==1.23.5
+pandas==2.1.3
+Pillow==10.1.0
+pooch==1.8.0
+progressbar==2.5
+protobuf==3.20.1
+pyarrow==14.0.1
+pycparser==2.21
+pydeck==0.8.1b0
+pytube==15.0.0
+pytz==2023.3.post1
+PyYAML==6.0.1
+redis==5.0.1
+referencing==0.31.0
+regex==2023.10.3
+requests==2.31.0
+rich==13.7.0
+rpds-py==0.13.0
+safetensors==0.4.0
+scikit-learn==1.3.2
+scipy==1.11.3
+sentry-sdk==1.35.0
+setproctitle==1.3.3
+smmap==5.0.1
+sniffio==1.3.0
+soundfile==0.12.1
+soxr==0.3.7
+spotipy==2.23.0
+streamlit==1.28.2
+tenacity==8.2.3
+threadpoolctl==3.2.0
+tokenizers==0.13.3
+toml==0.10.2
+toolz==0.12.0
+torch==1.11.0
+torchaudio==0.11.0
+torchlibrosa==0.1.0
+torchvision==0.12.0
+tqdm==4.66.1
+transformers==4.30.2
+tzdata==2023.3
+tzlocal==5.2
+uritemplate==4.1.1
+urllib3==2.1.0
+validators==0.22.0
+wandb==0.16.0
+webdataset==0.2.77
+wget==3.2
+youtube-search-python==1.6.6
\ No newline at end of file
diff --git a/src/config/__init__.py b/src/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/config/configs.py b/src/config/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e43f2327e4f40b62381b8703c8332cfe33d40e4
--- /dev/null
+++ b/src/config/configs.py
@@ -0,0 +1,17 @@
+from pathlib import Path
+from dataclasses import dataclass
+from os import getenv
+
+
+@dataclass
+class ProjectPaths:
+ ROOT: Path = Path(__file__).parents[2]
+ DATA_DIR: Path = ROOT.joinpath("data")
+ MODEL_PATH: Path = ROOT.joinpath("model_checkpoints", "music_audioset_epoch_15_esc_90.14.pt")
+
+
+@dataclass
+class Credentials:
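+    # Values are read from environment variables (for local runs, e.g. exported from the git-ignored .env file).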
+ SPOTIFY_CLIENT_ID: str = getenv("SPOTIFY_CLIENT_ID")
+ SPOTIFY_SECRET_ID: str = getenv("SPOTIFY_SECRET_ID")
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/data/get_yt_links.py b/src/data/get_yt_links.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7658f4b6d53c97559d9942ae31dd34bbeadb713
--- /dev/null
+++ b/src/data/get_yt_links.py
@@ -0,0 +1,45 @@
+from youtubesearchpython import VideosSearch
+import json
+import time
+from src.config.configs import ProjectPaths
+from tqdm import tqdm
+
+
+def read_json_data():
+ with open(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"), "r") as rd:
+ data = json.load(rd)
+ return data
+
+
+def get_track_link(artist_name, track_name):
+ search_result = VideosSearch(f'{artist_name} - {track_name}', limit=1)
+ result = search_result.result()["result"][0]
+ data = {
+ "artist_name": artist_name,
+ "track_name": track_name,
+ "duration": result.get("duration"),
+ "published_time": result.get("publishedTime"),
+ "title": result.get("title"),
+ "view_count": result.get("viewCount").get("text"),
+ "link": result.get("link")
+ }
+ return data
+
+
+def save_youtube_data(data):
+ with open(ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json"), "w") as wr:
+ json.dump(data, wr, indent=4)
+
+
+def collect_youtube_links():
+ data = read_json_data()
+ youtube_data = []
+ for track_data in tqdm(data):
+ yt_data = get_track_link(track_data["artist"], track_data["track"])
+ youtube_data.append(yt_data)
+ time.sleep(0.2)
+ save_youtube_data(youtube_data)
+
+
+if __name__ == "__main__":
+    collect_youtube_links()
diff --git a/src/data/pytuber.py b/src/data/pytuber.py
new file mode 100644
index 0000000000000000000000000000000000000000..46d6081b459c668d228d9a02b85269fa2e811b22
--- /dev/null
+++ b/src/data/pytuber.py
@@ -0,0 +1,36 @@
+import os
+from src.config.configs import ProjectPaths
+import json
+import pytube
+from tqdm import tqdm
+from pytube.exceptions import AgeRestrictedError
+
+
+def read_youtube_data():
+ input_data = ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json")
+ with open(input_data, "r") as rd:
+ return json.load(rd)
+
+
+def download_mp3(link, track_full_name):
+ data_dir = ProjectPaths.DATA_DIR.joinpath("audio")
+ try:
+ mp3 = pytube.YouTube(link, use_oauth=True, allow_oauth_cache=True).streams.filter(only_audio=True).first()
+ mp3.download(data_dir)
+
+ new_file = track_full_name + '.wav'
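+        # Note: this only renames the downloaded audio stream to .wav; it does not re-encode it.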
+        os.rename(data_dir.joinpath(mp3.default_filename), data_dir.joinpath(new_file))
+ except AgeRestrictedError:
+ pass
+
+
+def start_download_process():
+ input_data = read_youtube_data()
+ done_pieces = os.listdir(ProjectPaths.DATA_DIR.joinpath("audio"))
+ for i in tqdm(input_data):
+ link = i["link"]
+ full_name = f'{i["artist_name"]} - {i["track_name"]}'.replace("/", "_")
+ if full_name + ".wav" in done_pieces:
+ continue
+ download_mp3(link, full_name)
diff --git a/src/data/spotify.py b/src/data/spotify.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb98bc311ee79d18f0a69b5708589bde431a59c5
--- /dev/null
+++ b/src/data/spotify.py
@@ -0,0 +1,26 @@
+import spotipy
+from spotipy.oauth2 import SpotifyOAuth
+from ..config.configs import Credentials, ProjectPaths
+import json
+
+
+def list_personal_saved_tracks():
+ scope = "user-library-read"
+ auth = SpotifyOAuth(client_id=Credentials.SPOTIFY_CLIENT_ID, client_secret=Credentials.SPOTIFY_SECRET_ID, scope=scope, redirect_uri="https://localhost:5000")
+ sp = spotipy.Spotify(auth_manager=auth)
+
+ tracks = []
+ offset_count = 0
+ for _ in range(50):
+        results = sp.current_user_saved_tracks(limit=50, offset=offset_count)
+        if not results['items']:
+            break
+ for idx, item in enumerate(results['items']):
+ track = item['track']
+ data = {"artist": track['artists'][0]['name'], "track": track['name']}
+ tracks.append(data)
+ print(idx, track['artists'][0]['name'], " - ", track['name'])
+ offset_count += 50
+
+ with open(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"), "w", encoding="UTF-8") as wr:
+ json.dump(tracks, wr, indent=4)
diff --git a/src/laion_clap/__init__.py b/src/laion_clap/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..96d4b618dcb091479e2c9092ea2b807527f239de
--- /dev/null
+++ b/src/laion_clap/__init__.py
@@ -0,0 +1,5 @@
+import os
+import sys
+dir_path = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(dir_path)
+from .hook import CLAP_Module
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/__init__.py b/src/laion_clap/clap_module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b585be6540fe21eef8bc6594375baee5017877ef
--- /dev/null
+++ b/src/laion_clap/clap_module/__init__.py
@@ -0,0 +1,8 @@
+from .factory import list_models, create_model, create_model_and_transforms, add_model_config
+from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
+from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model
+from .openai import load_openai_model, list_openai_models
+from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
+ get_pretrained_url, download_pretrained
+from .tokenizer import SimpleTokenizer, tokenize
+from .transform import image_transform
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/bert.py b/src/laion_clap/clap_module/bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..005e72dec67e4b1c05063dbd4d024166344fd2c4
--- /dev/null
+++ b/src/laion_clap/clap_module/bert.py
@@ -0,0 +1,32 @@
+from transformers import BertTokenizer, BertModel
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+model = BertModel.from_pretrained("bert-base-uncased")
+text = "Replace me by any text you'd like."
+
+def bert_embeddings(text):
+ # text = "Replace me by any text you'd like."
+ encoded_input = tokenizer(text, return_tensors='pt')
+ output = model(**encoded_input)
+ return output
+
+from transformers import RobertaTokenizer, RobertaModel
+
+tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+model = RobertaModel.from_pretrained('roberta-base')
+text = "Replace me by any text you'd like."
+def Roberta_embeddings(text):
+ # text = "Replace me by any text you'd like."
+ encoded_input = tokenizer(text, return_tensors='pt')
+ output = model(**encoded_input)
+ return output
+
+from transformers import BartTokenizer, BartModel
+
+tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
+model = BartModel.from_pretrained('facebook/bart-base')
+text = "Replace me by any text you'd like."
+def bart_embeddings(text):
+ # text = "Replace me by any text you'd like."
+ encoded_input = tokenizer(text, return_tensors='pt')
+ output = model(**encoded_input)
+ return output
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz b/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/src/laion_clap/clap_module/factory.py b/src/laion_clap/clap_module/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e581ebcef11017eb5f48dd20682a066fd1a02fc
--- /dev/null
+++ b/src/laion_clap/clap_module/factory.py
@@ -0,0 +1,263 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+from packaging import version
+
+import torch
+import transformers
+
+from .model import CLAP, convert_weights_to_fp16
+from .openai import load_openai_model
+from .pretrained import get_pretrained_url, download_pretrained
+from .transform import image_transform
+
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = (".json",)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f"*{ext}"))
+
+ for cf in config_files:
+ with open(cf, "r") as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = {
+ k: v
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
+ }
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ if skip_params:
+ if next(iter(state_dict.items()))[0].startswith("module"):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # removing position_ids to maintain compatibility with latest transformers update
+ if version.parse(transformers.__version__) >= version.parse("4.31.0"):
+        state_dict.pop("text_branch.embeddings.position_ids", None)
+ # for k in state_dict:
+ # if k.startswith('transformer'):
+ # v = state_dict.pop(k)
+ # state_dict['text_branch.' + k[12:]] = v
+ return state_dict
+
+
+def create_model(
+ amodel_name: str,
+ tmodel_name: str,
+ pretrained: str = "",
+ precision: str = "fp32",
+ device: torch.device = torch.device("cpu"),
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
+ skip_params=True,
+ pretrained_audio: str = "",
+ pretrained_text: str = "",
+ enable_fusion: bool = False,
+ fusion_type: str = 'None'
+ # pretrained_image: bool = False,
+):
+ amodel_name = amodel_name.replace(
+ "/", "-"
+ ) # for callers using old naming with / in ViT names
+ pretrained_orig = pretrained
+ pretrained = pretrained.lower()
+ if pretrained == "openai":
+ if amodel_name in _MODEL_CONFIGS:
+ logging.info(f"Loading {amodel_name} model config.")
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+ else:
+ logging.error(
+ f"Model config for {amodel_name} not found; available models {list_models()}."
+ )
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
+ # Hard Code in model name
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
+ model = load_openai_model(
+ "ViT-B-16",
+ model_cfg,
+ device=device,
+ jit=jit,
+ cache_dir=openai_model_cache_dir,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type
+ )
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
+ if precision == "amp" or precision == "fp32":
+ model = model.float()
+ else:
+ if amodel_name in _MODEL_CONFIGS:
+ logging.info(f"Loading {amodel_name} model config.")
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+ else:
+ logging.error(
+ f"Model config for {amodel_name} not found; available models {list_models()}."
+ )
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ # if pretrained_image:
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
+ # # pretrained weight loading for timm models set via vision_cfg
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
+ # else:
+ # assert False, 'pretrained image towers currently only supported for timm models'
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
+ model_cfg["enable_fusion"] = enable_fusion
+ model_cfg["fusion_type"] = fusion_type
+ model = CLAP(**model_cfg)
+
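+        # `pretrained` is resolved first as a registered tag via get_pretrained_url();
+        # if no URL matches, the original string is treated as a local checkpoint path.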
+ if pretrained:
+ checkpoint_path = ""
+ url = get_pretrained_url(amodel_name, pretrained)
+ if url:
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
+ elif os.path.exists(pretrained_orig):
+ checkpoint_path = pretrained_orig
+ if checkpoint_path:
+ logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).")
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
+ model.load_state_dict(ckpt)
+ param_names = [n for n, p in model.named_parameters()]
+ for n in param_names:
+ print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
+ else:
+ logging.warning(
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+ )
+ raise RuntimeError(
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+ )
+
+ if pretrained_audio:
+ if amodel_name.startswith('PANN'):
+ if 'Cnn14_mAP' in pretrained_audio: # official checkpoint
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
+ audio_ckpt = audio_ckpt['model']
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key:
+ v = audio_ckpt.pop(key)
+ audio_ckpt['audio_branch.' + key] = v
+ elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
+ audio_ckpt = audio_ckpt['state_dict']
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith('sed_model'):
+ v = audio_ckpt.pop(key)
+ audio_ckpt['audio_branch.' + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
+ else:
+ raise ValueError('Unknown audio checkpoint')
+ elif amodel_name.startswith('HTSAT'):
+ if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
+ audio_ckpt = audio_ckpt['state_dict']
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith('sed_model') and ('spectrogram_extractor' not in key
+ and 'logmel_extractor' not in key):
+ v = audio_ckpt.pop(key)
+ audio_ckpt['audio_branch.' + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
+ audio_ckpt = audio_ckpt['state_dict']
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith('sed_model'):
+ v = audio_ckpt.pop(key)
+ audio_ckpt['audio_branch.' + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location='cpu')
+ else:
+ raise ValueError('Unknown audio checkpoint')
+            else:
+                raise ValueError('loading a pretrained audio checkpoint is not supported for this audio encoder')
+
+ model.load_state_dict(audio_ckpt, strict=False)
+ logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).")
+ param_names = [n for n, p in model.named_parameters()]
+ for n in param_names:
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
+
+ model.to(device=device)
+ if precision == "fp16":
+ assert device.type != "cpu"
+ convert_weights_to_fp16(model)
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model, model_cfg
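+# Illustrative call (the model/text names and path below are assumptions, not fixed by this module):
+#   model, model_cfg = create_model("HTSAT-base", "roberta", pretrained="/path/to/ckpt.pt",
+#                                   enable_fusion=False, fusion_type='None')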
+
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: str = "",
+ precision: str = "fp32",
+ device: torch.device = torch.device("cpu"),
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ # pretrained_image: bool = False,
+):
+    # NOTE: carried over from open_clip; create_model() returns (model, model_cfg), and CLAP
+    # models have no `visual` image tower, so this helper is effectively legacy.
+    model, _ = create_model(
+ model_name,
+ pretrained,
+ precision,
+ device,
+ jit,
+ force_quick_gelu=force_quick_gelu,
+ # pretrained_image=pretrained_image
+ )
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
+ return model, preprocess_train, preprocess_val
+
+
+def list_models():
+ """enumerate available model architectures based on config files"""
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """add model config path or file and update registry"""
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
diff --git a/src/laion_clap/clap_module/feature_fusion.py b/src/laion_clap/clap_module/feature_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2419516b76931f0aa801d78e1b5f04a92a909e6
--- /dev/null
+++ b/src/laion_clap/clap_module/feature_fusion.py
@@ -0,0 +1,193 @@
+'''
+Feature Fusion for Variable-Length Data Processing
+AFF/iAFF is adapted and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
+Reference: Yimian Dai et al., Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision (WACV), 2021
+'''
+
+import torch
+import torch.nn as nn
+
+
+class DAF(nn.Module):
+    '''
+    DirectAddFuse: fuse by direct element-wise addition
+    '''
+
+ def __init__(self):
+ super(DAF, self).__init__()
+
+ def forward(self, x, residual):
+ return x + residual
+
+
+class iAFF(nn.Module):
+    '''
+    iAFF: iterative attentional multi-feature fusion
+    '''
+
+ def __init__(self, channels=64, r=4, type='2D'):
+ super(iAFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ if type == '1D':
+            # local attention
+ self.local_att = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+
+            # global attention
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+
+            # second local attention
+ self.local_att2 = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+            # second global attention
+ self.global_att2 = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ elif type == '2D':
+            # local attention
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+            # global attention
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+            # second local attention
+ self.local_att2 = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+            # second global attention
+ self.global_att2 = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+        else:
+            raise ValueError(f'the fusion type {type} is not supported')
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, residual):
+ flag = False
+ xa = x + residual
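+        # a batch of size 1 is duplicated here (and reduced back after fusion), presumably so
+        # the BatchNorm layers inside the attention branches see more than one sample in training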
+ if xa.size(0) == 1:
+ xa = torch.cat([xa,xa],dim=0)
+ flag = True
+ xl = self.local_att(xa)
+ xg = self.global_att(xa)
+ xlg = xl + xg
+ wei = self.sigmoid(xlg)
+ xi = x * wei + residual * (1 - wei)
+
+ xl2 = self.local_att2(xi)
+        xg2 = self.global_att(xi)  # note: the second pass reuses global_att; global_att2 defined above is not applied here
+ xlg2 = xl2 + xg2
+ wei2 = self.sigmoid(xlg2)
+ xo = x * wei2 + residual * (1 - wei2)
+ if flag:
+ xo = xo[0].unsqueeze(0)
+ return xo
+
+
+class AFF(nn.Module):
+    '''
+    AFF: attentional multi-feature fusion
+    '''
+
+ def __init__(self, channels=64, r=4, type='2D'):
+ super(AFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ if type == '1D':
+ self.local_att = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ elif type == '2D':
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+        else:
+            raise ValueError(f'the fusion type {type} is not supported.')
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, residual):
+ flag = False
+ xa = x + residual
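+        # same single-sample workaround as in iAFF.forward: duplicate a batch of 1 for the
+        # BatchNorm layers and keep only the first output afterwards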
+ if xa.size(0) == 1:
+ xa = torch.cat([xa,xa],dim=0)
+ flag = True
+ xl = self.local_att(xa)
+ xg = self.global_att(xa)
+ xlg = xl + xg
+ wei = self.sigmoid(xlg)
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
+ if flag:
+ xo = xo[0].unsqueeze(0)
+ return xo
+
diff --git a/src/laion_clap/clap_module/htsat.py b/src/laion_clap/clap_module/htsat.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb8e7cf5f2307c57e094a122121f3ca7f527436a
--- /dev/null
+++ b/src/laion_clap/clap_module/htsat.py
@@ -0,0 +1,1031 @@
+# Ke Chen
+# knutchen@ucsd.edu
+# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+# Some layers are designed for this model
+# The code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
+# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from itertools import repeat
+import collections.abc
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+import torch.utils.checkpoint as checkpoint
+
+import random
+
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import do_mixup, interpolate
+
+from .feature_fusion import iAFF, AFF, DAF
+
+# from PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
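+    # e.g. drop_prob=0.1 zeroes the residual branch for ~10% of the samples and rescales the
+    # surviving ones by 1/(1-0.1), keeping the expected value unchanged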
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class PatchEmbed(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16,
+ enable_fusion=False, fusion_type='None'):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patch_stride = to_2tuple(patch_stride)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+ self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
+
+ if (self.enable_fusion) and (self.fusion_type == 'channel_map'):
+ self.proj = nn.Conv2d(in_chans*4, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
+ else:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
+ self.mel_conv2d = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_size[0], patch_size[1]*3), stride=(patch_stride[0], patch_stride[1] * 3), padding=padding)
+ if self.fusion_type == 'daf_2d':
+ self.fusion_model = DAF()
+ elif self.fusion_type == 'aff_2d':
+ self.fusion_model = AFF(channels=embed_dim, type='2D')
+ elif self.fusion_type == 'iaff_2d':
+ self.fusion_model = iAFF(channels=embed_dim, type='2D')
+ def forward(self, x, longer_idx = None):
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
+ global_x = x[:,0:1,:,:]
+
+
+ # global processing
+ B, C, H, W = global_x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ global_x = self.proj(global_x)
+ TW = global_x.size(-1)
+ if len(longer_idx) > 0:
+ # local processing
+ local_x = x[longer_idx,1:,:,:].contiguous()
+ B, C, H, W = local_x.shape
+ local_x = local_x.view(B*C,1,H,W)
+ local_x = self.mel_conv2d(local_x)
+ local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3))
+ local_x = local_x.permute((0,2,3,1,4)).contiguous().flatten(3)
+ TB,TC,TH,_ = local_x.size()
+ if local_x.size(-1) < TW:
+ local_x = torch.cat([local_x, torch.zeros((TB,TC,TH,TW-local_x.size(-1)), device=global_x.device)], dim=-1)
+ else:
+ local_x = local_x[:,:,:,:TW]
+
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx],local_x)
+ x = global_x
+ else:
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == 'fan_in':
+ denom = fan_in
+ elif mode == 'fan_out':
+ denom = fan_out
+ elif mode == 'fan_avg':
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+ elif distribution == "normal":
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
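+    # e.g. x of shape (B, 64, 64, C) with window_size=8 yields (B*64, 8, 8, C)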
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x, attn
+
+ def extra_repr(self):
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+
+# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.norm_before_mlp = norm_before_mlp
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ if self.norm_before_mlp == 'ln':
+ self.norm2 = nn.LayerNorm(dim)
+ elif self.norm_before_mlp == 'bn':
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
+ else:
+ raise NotImplementedError
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ # pdb.set_trace()
+ H, W = self.input_resolution
+ # print("H: ", H)
+ # print("W: ", W)
+ # pdb.set_trace()
+ B, L, C = x.shape
+ # assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x, attn
+
+ def extra_repr(self):
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
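+        # e.g. with input_resolution=(64, 64): (B, 4096, C) -> (B, 1024, 2*C)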
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self):
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ norm_before_mlp='ln'):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ attns = []
+ for blk in self.blocks:
+            if self.use_checkpoint:
+                # blk returns (x, attn); unpack both so the checkpointed path matches the eager path
+                x, attn = checkpoint.checkpoint(blk, x)
+            else:
+                x, attn = blk(x)
+ if not self.training:
+ attns.append(attn.unsqueeze(0))
+ if self.downsample is not None:
+ x = self.downsample(x)
+ if not self.training:
+ attn = torch.cat(attns, dim = 0)
+ attn = torch.mean(attn, dim = 0)
+ return x, attn
+
+ def extra_repr(self):
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+# The Core of HTSAT
+class HTSAT_Swin_Transformer(nn.Module):
+ r"""HTSAT based on the Swin Transformer
+ Args:
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
+ patch_size (int | tuple(int)): Patch size. Default: 4
+        patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
+ in_chans (int): Number of input image channels. Default: 1 (mono)
+ num_classes (int): Number of classes for classification head. Default: 527
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 8
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ config (module): The configuration Module from config.py
+ """
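+    # The (time x mel_bins) log-mel spectrogram is folded into a square spec_size x spec_size
+    # "image" by reshape_wav2img(), so the standard Swin attention windows can be reused for audio.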
+
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
+ in_chans=1, num_classes=527,
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False, patch_norm=True,
+ use_checkpoint=False, norm_before_mlp='ln', config = None,
+ enable_fusion = False, fusion_type = 'None', **kwargs):
+ super(HTSAT_Swin_Transformer, self).__init__()
+
+ self.config = config
+ self.spec_size = spec_size
+ self.patch_stride = patch_stride
+ self.patch_size = patch_size
+ self.window_size = window_size
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.ape = ape
+ self.in_chans = in_chans
+ self.num_classes = num_classes
+ self.num_heads = num_heads
+ self.num_layers = len(self.depths)
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.drop_rate = drop_rate
+ self.attn_drop_rate = attn_drop_rate
+ self.drop_path_rate = drop_path_rate
+
+ self.qkv_bias = qkv_bias
+ self.qk_scale = None
+
+ self.patch_norm = patch_norm
+ self.norm_layer = norm_layer if self.patch_norm else None
+ self.norm_before_mlp = norm_before_mlp
+ self.mlp_ratio = mlp_ratio
+
+ self.use_checkpoint = use_checkpoint
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # process mel-spec ; used only once
+ self.freq_ratio = self.spec_size // self.config.mel_bins
+ window = 'hann'
+ center = True
+ pad_mode = 'reflect'
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+ self.interpolate_ratio = 32 # Downsampled ratio
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
+ freeze_parameters=True)
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
+ freeze_parameters=True)
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
+
+
+        # split the spectrogram into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride,
+ enable_fusion=self.enable_fusion, fusion_type=self.fusion_type
+ )
+
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.grid_size
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
+ patches_resolution[1] // (2 ** i_layer)),
+ depth=self.depths[i_layer],
+ num_heads=self.num_heads[i_layer],
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
+ norm_layer=self.norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ norm_before_mlp=self.norm_before_mlp)
+ self.layers.append(layer)
+
+ self.norm = self.norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
+
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
+ self.tscam_conv = nn.Conv2d(
+ in_channels = self.num_features,
+ out_channels = self.num_classes,
+ kernel_size = (SF,3),
+ padding = (0,1)
+ )
+ self.head = nn.Linear(num_classes, num_classes)
+
+ if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']):
+ self.mel_conv1d = nn.Sequential(
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+ nn.BatchNorm1d(64)
+ )
+ if self.fusion_type == 'daf_1d':
+ self.fusion_model = DAF()
+ elif self.fusion_type == 'aff_1d':
+ self.fusion_model = AFF(channels=64, type='1D')
+ elif self.fusion_type == 'iaff_1d':
+ self.fusion_model = iAFF(channels=64, type='1D')
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+
+ def forward_features(self, x, longer_idx = None):
+ # A deprecated optimization for using a hierarchical output from different blocks
+
+ frames_num = x.shape[2]
+ x = self.patch_embed(x, longer_idx = longer_idx)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ for i, layer in enumerate(self.layers):
+ x, attn = layer(x)
+ # for x
+ x = self.norm(x)
+ B, N, C = x.shape
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
+ B, C, F, T = x.shape
+ # group 2D CNN
+ c_freq_bin = F // self.freq_ratio
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
+ # get latent_output
+ fine_grained_latent_output = torch.mean(x, dim = 2)
+ fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
+
+ latent_output = self.avgpool(torch.flatten(x,2))
+ latent_output = torch.flatten(latent_output, 1)
+
+ # display the attention map, if needed
+
+ x = self.tscam_conv(x)
+ x = torch.flatten(x, 2) # B, C, T
+
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+
+ output_dict = {
+ 'framewise_output': fpx, # already sigmoided
+ 'clipwise_output': torch.sigmoid(x),
+ 'fine_grained_embedding': fine_grained_latent_output,
+ 'embedding': latent_output
+ }
+
+ return output_dict
+
+ def crop_wav(self, x, crop_size, spe_pos = None):
+ time_steps = x.shape[2]
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
+ for i in range(len(x)):
+ if spe_pos is None:
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
+ else:
+ crop_pos = spe_pos
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
+ return tx
+
+    # Reshape the waveform spectrogram to an image-like size so the pretrained Swin Transformer can be reused
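+    # e.g. with spec_size=256 and mel_bins=64 (freq_ratio=4), a (B, 1, 1024, 64) log-mel
+    # tensor is folded into a (B, 1, 256, 256) image-like tensor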
+ def reshape_wav2img(self, x):
+ B, C, T, F = x.shape
+ target_T = int(self.spec_size * self.freq_ratio)
+ target_F = self.spec_size // self.freq_ratio
+        assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
+ # to avoid bicubic zero error
+ if T < target_T:
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
+ if F < target_F:
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
+ x = x.permute(0,1,3,2).contiguous()
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
+ # print(x.shape)
+ x = x.permute(0,1,3,2,4).contiguous()
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
+ return x
+
+    # Repeat the waveform spectrogram to an image-like size so the pretrained Swin Transformer can be reused
+ def repeat_wat2img(self, x, cur_pos):
+ B, C, T, F = x.shape
+ target_T = int(self.spec_size * self.freq_ratio)
+ target_F = self.spec_size // self.freq_ratio
+        assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
+ # to avoid bicubic zero error
+ if T < target_T:
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
+ if F < target_F:
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
+ x = x.permute(0,1,3,2).contiguous() # B C F T
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
+ x = x.repeat(repeats = (1,1,4,1))
+ return x
+
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None):
+
+ if self.enable_fusion and x["longer"].sum() == 0:
+ # if no audio is longer than 10s, then randomly select one audio to be longer
+ if self.training:
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
+ else:
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x, longer_idx=[])
+ return output_dict
+
+ if not self.enable_fusion:
+ x = x["waveform"].to(device=device, non_blocking=True)
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x)
+ else:
+ longer_list = x["longer"].to(device=device, non_blocking=True)
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ longer_list_idx = torch.where(longer_list)[0]
+ if self.fusion_type in ['daf_1d','aff_1d','iaff_1d']:
+ new_x = x[:,0:1,:,:].clone().contiguous()
+ if len(longer_list_idx) > 0:
+ # local processing
+ fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous()
+ FB,FC,FT,FF = fusion_x_local.size()
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous()
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
+ fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1))
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2)
+ if fusion_x_local.size(-1) < FT:
+ fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1)
+ else:
+ fusion_x_local = fusion_x_local[:,:,:FT]
+ # 1D fusion
+ new_x = new_x.squeeze(1).permute((0,2,1)).contiguous()
+ new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local)
+ x = new_x.permute((0,2,1)).contiguous()[:,None,:,:]
+ else:
+ x = new_x
+
+            elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']:
+                x = x  # no change here; 2D fusion is applied later inside PatchEmbed via longer_idx
+
+ if self.training:
+ x = self.spec_augmenter(x)
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x, longer_idx = longer_list_idx)
+
+ # if infer_mode:
+ # # in infer mode. we need to handle different length audio input
+ # frame_num = x.shape[2]
+ # target_T = int(self.spec_size * self.freq_ratio)
+ # repeat_ratio = math.floor(target_T / frame_num)
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # else:
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
+ # if self.training:
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # else:
+ # # Change: Hard code here
+ # overlap_size = (x.shape[2] - 1) // 4
+ # output_dicts = []
+ # crop_size = (x.shape[2] - 1) // 2
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
+ # tx = self.reshape_wav2img(tx)
+ # output_dicts.append(self.forward_features(tx))
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
+ # for d in output_dicts:
+ # clipwise_output += d["clipwise_output"]
+ # framewise_output += d["framewise_output"]
+ # clipwise_output = clipwise_output / len(output_dicts)
+ # framewise_output = framewise_output / len(output_dicts)
+ # output_dict = {
+ # 'framewise_output': framewise_output,
+ # 'clipwise_output': clipwise_output
+ # }
+ # else: # this part is typically used, and most easy one
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # x = self.head(x)
+
+        # The data is already processed in the dataloader, so here we only handle inputs with input_T < fixed_T
+
+
+
+ return output_dict
+
+def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type='None'):
+ try:
+
+        assert audio_cfg.model_name in ["tiny", "base", "large"], "model_name for HTS-AT must be one of 'tiny', 'base', or 'large'"
+ if audio_cfg.model_name == "tiny":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4,4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=96,
+ depths=[2,2,6,2],
+ num_heads=[4,8,16,32],
+ window_size=8,
+ config = audio_cfg,
+ enable_fusion = enable_fusion,
+ fusion_type = fusion_type
+ )
+ elif audio_cfg.model_name == "base":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4,4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=128,
+ depths=[2,2,12,2],
+ num_heads=[4,8,16,32],
+ window_size=8,
+ config = audio_cfg,
+ enable_fusion = enable_fusion,
+ fusion_type = fusion_type
+ )
+ elif audio_cfg.model_name == "large":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4,4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=256,
+ depths=[2,2,12,2],
+ num_heads=[4,8,16,32],
+ window_size=8,
+ config = audio_cfg,
+ enable_fusion = enable_fusion,
+ fusion_type = fusion_type
+ )
+
+ return model
+    except Exception as e:
+        raise RuntimeError(f'HTS-AT model "{audio_cfg.model_name}" could not be created, or the audio cfg is missing required parameters.') from e
+
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/linear_probe.py b/src/laion_clap/clap_module/linear_probe.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb2841dd4e28201db8b5bd4a215e1b8b9a60d25a
--- /dev/null
+++ b/src/laion_clap/clap_module/linear_probe.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+from .model import MLPLayers
+
+
+class LinearProbe(nn.Module):
+ def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
+ """
+ Args:
+ model: nn.Module
+ mlp: bool, if True, then use the MLP layer as the linear probe module
+            freeze: bool, if True, freeze all of the CLAP model's layers when training the linear probe
+ in_ch: int, the output channel from CLAP model
+ out_ch: int, the output channel from linear probe (class_num)
+ act: torch.nn.functional, the activation function before the loss function
+ """
+ super().__init__()
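+        # note: in_ch is overridden to 512 below, which appears to be the fixed CLAP joint-embedding size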
+ in_ch = 512
+ self.clap_model = model
+ self.clap_model.text_branch = None # to save memory
+ self.freeze = freeze
+ if mlp:
+ self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
+ else:
+ self.lp_layer = nn.Linear(in_ch, out_ch)
+
+ if self.freeze:
+ for param in self.clap_model.parameters():
+ param.requires_grad = False
+
+        if act is None or act == 'None':
+            self.act = None
+        elif act == 'relu':
+            self.act = nn.ReLU()
+        elif act == 'elu':
+            self.act = nn.ELU()
+        elif act == 'prelu':
+            self.act = nn.PReLU(num_parameters=in_ch)
+        elif act == 'softmax':
+            self.act = nn.Softmax(dim=-1)
+        elif act == 'sigmoid':
+            self.act = nn.Sigmoid()
+        else:
+            raise ValueError(f'activation "{act}" is not supported')
+
+ def forward(self, x, mix_lambda=None, device=None):
+ """
+ Args:
+ x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
+ mix_lambda: torch.tensor [batch], the mixup lambda
+ Returns:
+ class_prob: torch.tensor [batch, class_num]
+
+ """
+        # keep the frozen CLAP backbone in eval mode so BatchNorm running statistics (and dropout) are not updated
+ if self.freeze:
+ self.clap_model.eval()
+
+ x = self.clap_model.audio_projection(
+ self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"])
+ out = self.lp_layer(x)
+ if self.act is not None:
+ out = self.act(out)
+ return out
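+
+# Illustrative usage (the class count and activation below are assumptions):
+#   probe = LinearProbe(clap_model, mlp=False, freeze=True, in_ch=512, out_ch=527, act='sigmoid')
+#   class_prob = probe(batch)  # [batch, 527]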
diff --git a/src/laion_clap/clap_module/loss.py b/src/laion_clap/clap_module/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..53bbedd959813b072b146c16c14cd96df6cada14
--- /dev/null
+++ b/src/laion_clap/clap_module/loss.py
@@ -0,0 +1,307 @@
+import torch
+import torch.distributed.nn
+from torch import distributed as dist, nn as nn
+from torch.nn import functional as F
+import numpy as np
+from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def gather_features(
+ audio_features,
+ text_features,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ mlp_loss=False
+):
+ if use_horovod:
+ assert hvd is not None, 'Please install horovod'
+ if gather_with_grad:
+ all_audio_features = hvd.allgather(audio_features)
+ all_text_features = hvd.allgather(text_features)
+ if mlp_loss:
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
+ else:
+ with torch.no_grad():
+ all_audio_features = hvd.allgather(audio_features)
+ all_text_features = hvd.allgather(text_features)
+ if mlp_loss:
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_audio_features = list(all_audio_features.chunk(world_size, dim=0))
+ gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+ gathered_audio_features[rank] = audio_features
+ gathered_text_features[rank] = text_features
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ if mlp_loss:
+ gathered_audio_features_mlp = list(all_audio_features_mlp.chunk(world_size, dim=0))
+ gathered_text_features_mlp = list(all_text_features_mlp.chunk(world_size, dim=0))
+ gathered_audio_features_mlp[rank] = audio_features_mlp
+ gathered_text_features_mlp[rank] = text_features_mlp
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+ else:
+ # We gather tensors from all gpus
+ if gather_with_grad:
+ all_audio_features = torch.cat(torch.distributed.nn.all_gather(audio_features), dim=0)
+ all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+ if mlp_loss:
+ all_audio_features_mlp = torch.cat(torch.distributed.nn.all_gather(audio_features_mlp), dim=0)
+ all_text_features_mlp = torch.cat(torch.distributed.nn.all_gather(text_features_mlp), dim=0)
+ else:
+ gathered_audio_features = [torch.zeros_like(audio_features) for _ in range(world_size)]
+ gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+ dist.all_gather(gathered_audio_features, audio_features)
+ dist.all_gather(gathered_text_features, text_features)
+ if mlp_loss:
+ gathered_audio_features_mlp = [torch.zeros_like(audio_features_mlp) for _ in range(world_size)]
+ gathered_text_features_mlp = [torch.zeros_like(text_features_mlp) for _ in range(world_size)]
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_audio_features[rank] = audio_features
+ gathered_text_features[rank] = text_features
+ if mlp_loss:
+ gathered_audio_features_mlp[rank] = audio_features_mlp
+ gathered_text_features_mlp[rank] = text_features_mlp
+
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ if mlp_loss:
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+ if mlp_loss:
+ return all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp
+ else:
+ return all_audio_features, all_text_features
+
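+# ClipLoss below implements the symmetric audio<->text contrastive objective: for a batch of N
+# matched pairs the i-th audio should score highest against the i-th text (and vice versa), so
+# the cross-entropy targets are simply torch.arange(N).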
+class ClipLoss(nn.Module):
+
+ def __init__(
+ self,
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ mlp_loss=False,
+ weight_loss_kappa=0,
+ ):
+ super().__init__()
+ self.local_loss = local_loss
+ self.gather_with_grad = gather_with_grad
+ self.cache_labels = cache_labels
+ self.rank = rank
+ self.world_size = world_size
+ self.use_horovod = use_horovod
+ self.mlp_loss = mlp_loss
+ self.weighted_loss = bool(weight_loss_kappa!=0)
+ self.weight_loss_kappa = weight_loss_kappa
+ # cache state
+ self.prev_num_logits = 0
+ self.labels = {}
+
+ def forward(self, audio_features, text_features, logit_scale_a, logit_scale_t=None, audio_features_mlp=None, text_features_mlp=None):
+ device = audio_features.device
+ if self.mlp_loss:
+ if self.world_size > 1:
+ all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = gather_features(
+ audio_features=audio_features,text_features=text_features,
+ audio_features_mlp=audio_features_mlp,text_features_mlp=text_features_mlp,
+ local_loss=self.local_loss,gather_with_grad=self.gather_with_grad,
+ rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod,
+ mlp_loss=self.mlp_loss
+ )
+ if self.local_loss:
+ a_logits_per_audio = logit_scale_a * audio_features @ all_text_features_mlp.T
+ a_logits_per_text = logit_scale_a * text_features_mlp @ all_audio_features.T
+ t_logits_per_audio = logit_scale_t * audio_features_mlp @ all_text_features.T
+ t_logits_per_text = logit_scale_t * text_features @ all_audio_features_mlp.T
+ else:
+ a_logits_per_audio = logit_scale_a * all_audio_features @ all_text_features_mlp.T
+ a_logits_per_text = a_logits_per_audio.T
+ t_logits_per_audio = logit_scale_t * all_audio_features_mlp @ all_text_features.T
+ t_logits_per_text = t_logits_per_audio.T
+ else:
+ a_logits_per_audio = logit_scale_a * audio_features @ text_features_mlp.T
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
+ t_logits_per_audio = logit_scale_t * audio_features_mlp @ text_features.T
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
+
+            # calculate the ground-truth labels and cache them if enabled
+ num_logits = a_logits_per_audio.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+
+ if not self.weighted_loss:
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels) +
+ F.cross_entropy(a_logits_per_text, labels) +
+ F.cross_entropy(t_logits_per_audio, labels) +
+ F.cross_entropy(t_logits_per_text, labels)
+ ) / 4
+ else:
+ audio_weight = (audio_features@audio_features.T).detach()
+ audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(audio_weight)))).detach()
+ text_weight = (text_features@text_features.T).detach()
+ text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(text_features)))).detach()
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) +
+ F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) +
+ F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) +
+ F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
+ ) / 4
+ else:
+ if self.world_size > 1:
+ all_audio_features, all_text_features = gather_features(
+ audio_features=audio_features,text_features=text_features,
+ local_loss=self.local_loss,gather_with_grad=self.gather_with_grad,
+ rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod,
+ mlp_loss=self.mlp_loss
+ )
+
+ if self.local_loss:
+ logits_per_audio = logit_scale_a * audio_features @ all_text_features.T
+ logits_per_text = logit_scale_a * text_features @ all_audio_features.T
+ else:
+ logits_per_audio = logit_scale_a * all_audio_features @ all_text_features.T
+ logits_per_text = logits_per_audio.T
+ else:
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
+
+                # calculate the ground-truth labels and cache them if enabled
+ num_logits = logits_per_audio.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+ if not self.weighted_loss:
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels) +
+ F.cross_entropy(logits_per_text, labels)
+ ) / 2
+ else:
+ audio_weight = (all_audio_features@all_audio_features.T).detach()
+ audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(all_audio_features)))).detach()
+ text_weight = (all_text_features@all_text_features.T).detach()
+ text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(all_text_features)))).detach()
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight) +
+ F.cross_entropy(logits_per_text, labels, weight=audio_weight)
+ ) / 2
+ return total_loss
+
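+# A minimal single-process sketch of how ClipLoss might be exercised; the batch size,
+# embedding dimension and logit scale below are illustrative assumptions rather than
+# values taken from any training configuration.
+def _clip_loss_demo(batch_size=8, embed_dim=512):
+    loss_fn = ClipLoss(local_loss=False, gather_with_grad=False, world_size=1, mlp_loss=False)
+    audio_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
+    text_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
+    logit_scale_a = torch.tensor(1.0 / 0.07)  # the model passes logit_scale_a.exp(), roughly this value
+    # symmetric cross-entropy over the audio->text and text->audio similarity matrices
+    return loss_fn(audio_features, text_features, logit_scale_a)
+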
+def lp_gather_features(
+ pred,
+ target,
+ world_size=1,
+ use_horovod=False
+):
+ if use_horovod:
+ assert hvd is not None, 'Please install horovod'
+ with torch.no_grad():
+ all_preds = hvd.allgather(pred)
+            all_targets = hvd.allgather(target)
+ else:
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
+
+ dist.all_gather(gathered_preds, pred)
+ dist.all_gather(gathered_targets, target)
+ all_preds = torch.cat(gathered_preds, dim=0)
+ all_targets = torch.cat(gathered_targets, dim=0)
+
+ return all_preds, all_targets
+
+
+def get_map(pred, target):
+ pred = torch.sigmoid(pred).numpy()
+ target = target.numpy()
+ return np.mean(average_precision_score(target, pred, average=None))
+
+def get_acc(pred, target):
+ pred = torch.argmax(pred,1).numpy()
+ target = torch.argmax(target,1).numpy()
+ return accuracy_score(target, pred)
+
+def get_mauc(pred, target):
+ pred = torch.sigmoid(pred).numpy()
+ target = target.numpy()
+ return np.mean(roc_auc_score(target, pred, average=None))
+
+
+class LPMetrics(object):
+ def __init__(self, metric_names = ['map','acc','mauc']):
+ self.metrics = []
+ for name in metric_names:
+ self.metrics.append(self.get_metric(name))
+ self.metric_names = metric_names
+
+ def get_metric(self,name):
+ if name == 'map':
+ return get_map
+ elif name == 'acc':
+ return get_acc
+ elif name == 'mauc':
+ return get_mauc
+ else:
+            raise ValueError('the metric should be one of [map, acc, mauc]')
+
+ def evaluate_mertics(self, pred, target):
+ metric_dict = {}
+ for i in range(len(self.metric_names)):
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
+ return metric_dict
+
+
+def calc_celoss(pred, target):
+ target = torch.argmax(target, 1).long()
+ return nn.CrossEntropyLoss()(pred, target)
+
+
+class LPLoss(nn.Module):
+
+ def __init__(self, loss_name):
+ super().__init__()
+ if loss_name == 'bce':
+ self.loss_func = nn.BCEWithLogitsLoss()
+ elif loss_name == 'ce':
+ self.loss_func = calc_celoss
+ elif loss_name == 'mse':
+ self.loss_func = nn.MSELoss()
+ else:
+            raise ValueError('the loss function should be one of [bce, ce, mse]')
+
+ def forward(self, pred, target):
+ loss = self.loss_func(pred, target)
+ return loss
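+
+
+# A minimal sketch of how LPMetrics and LPLoss might be used together on random
+# multi-label data; the batch size and class count are illustrative assumptions.
+def _lp_demo(batch_size=32, num_classes=10):
+    pred = torch.randn(batch_size, num_classes)
+    target = (torch.rand(batch_size, num_classes) > 0.5).float()
+    loss = LPLoss('bce')(pred, target)
+    metrics = LPMetrics(metric_names=['map', 'acc', 'mauc'])
+    return loss, metrics.evaluate_mertics(pred, target)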
+
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model.py b/src/laion_clap/clap_module/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..60663ec2658cb302f093625c5ce02fc843e6a5bc
--- /dev/null
+++ b/src/laion_clap/clap_module/model.py
@@ -0,0 +1,892 @@
+""" CLAP Model
+
+Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+Adapted to the Audio Task.
+"""
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from typing import Tuple, Union, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .timm_model import TimmModel
+import logging
+from .utils import freeze_batch_norm_2d
+
+from .pann_model import create_pann_model
+from .htsat import create_htsat_model
+from transformers import BertModel, RobertaModel, BartModel
+from transformers.tokenization_utils_base import BatchEncoding
+
+
+class MLPLayers(nn.Module):
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
+ super(MLPLayers, self).__init__()
+ self.nonlin = nonlin
+ self.dropout = dropout
+
+ sequence = []
+ for u0, u1 in zip(units[:-1], units[1:]):
+ sequence.append(nn.Linear(u0, u1))
+ sequence.append(self.nonlin)
+ sequence.append(nn.Dropout(self.dropout))
+ sequence = sequence[:-2]
+
+ self.sequential = nn.Sequential(*sequence)
+
+ def forward(self, X):
+ X = self.sequential(X)
+ return X
+
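+# A minimal sketch of MLPLayers as it is configured later in this file (three equal-width
+# layers, dropout 0.1); the width and batch size here are illustrative assumptions.
+def _mlp_layers_demo(width=512, batch_size=4):
+    mlp = MLPLayers(units=[width, width, width], dropout=0.1)
+    out = mlp(torch.randn(batch_size, width))
+    return out.shape  # expected: torch.Size([batch_size, width])
+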
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(
+ OrderedDict(
+ [
+ ("-1", nn.AvgPool2d(stride)),
+ (
+ "0",
+ nn.Conv2d(
+ inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=False,
+ ),
+ ),
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
+ )
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
+ 2, 0, 1
+ ) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x,
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat(
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
+ ),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
+ )
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features**-0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert (
+ unlocked_groups == 0
+ ), "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ def stem(self, x):
+ for conv, bn in [
+ (self.conv1, self.bn1),
+ (self.conv2, self.bn2),
+ (self.conv3, self.bn3),
+ ]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
+ ]
+ )
+ )
+ self.ln_2 = LayerNorm(d_model)
+
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisualTransformer(nn.Module):
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ output_dim: int,
+ act_layer: Callable = nn.GELU,
+ ):
+ super().__init__()
+ self.image_size = image_size
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False,
+ )
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
+ )
+ self.ln_pre = LayerNorm(width)
+
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert (
+ unlocked_groups == 0
+ ), "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat(
+ [
+ self.class_embedding.to(x.dtype)
+ + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
+ ),
+ x,
+ ],
+ dim=1,
+ ) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_branch(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+@dataclass
+class CLAPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+ timm_model_name: str = (
+ None # a valid model name overrides layers, width, patch_size
+ )
+ timm_model_pretrained: bool = (
+ False # use (imagenet) pretrained weights for named model
+ )
+ timm_pool: str = (
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ )
+ timm_proj: str = (
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
+ )
+
+
+# Audio Config Class
+@dataclass
+class CLAPAudioCfp:
+ model_type: str = "PANN"
+ model_name: str = "Cnn14"
+ sample_rate: int = 48000
+ # Param
+ audio_length: int = 1024
+ window_size: int = 1024
+ hop_size: int = 1024
+ fmin: int = 50
+ fmax: int = 14000
+ class_num: int = 527
+ mel_bins: int = 64
+ clip_samples: int = 480000
+
+
+@dataclass
+class CLAPTextCfg:
+ context_length: int
+ vocab_size: int
+ width: int
+ heads: int
+ layers: int
+ model_type: str
+
+
+class CLAP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ audio_cfg: CLAPAudioCfp,
+ text_cfg: CLAPTextCfg,
+ quick_gelu: bool = False,
+ enable_fusion: bool = False,
+ fusion_type: str = 'None',
+ joint_embed_shape: int = 512,
+ mlp_act: str = 'relu',
+ ):
+ super().__init__()
+ if isinstance(audio_cfg, dict):
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
+ if isinstance(text_cfg, dict):
+ text_cfg = CLAPTextCfg(**text_cfg)
+
+ self.audio_cfg = audio_cfg
+ self.text_cfg = text_cfg
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+ self.joint_embed_shape = joint_embed_shape
+ self.mlp_act = mlp_act
+
+
+ self.context_length = text_cfg.context_length
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if mlp_act == 'relu':
+ mlp_act_layer = nn.ReLU()
+ elif mlp_act == 'gelu':
+ mlp_act_layer = nn.GELU()
+ else:
+ raise NotImplementedError
+
+ # audio branch
+ # audio branch parameters
+ if audio_cfg.model_type == "PANN":
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
+ elif audio_cfg.model_type == "HTSAT":
+ self.audio_branch = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
+ else:
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
+
+ # text branch
+ # text branch parameters
+ if text_cfg.model_type == "transformer":
+ self.text_branch = Transformer(
+ width=text_cfg.width,
+ layers=text_cfg.layers,
+ heads=text_cfg.heads,
+ act_layer=act_layer,
+ )
+ self.vocab_size = text_cfg.vocab_size
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, text_cfg.width)
+ )
+ self.ln_final = LayerNorm(text_cfg.width)
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape], dropout=0.1)
+ self.text_projection = nn.Sequential(
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
+ )
+ elif text_cfg.model_type == "bert":
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape], dropout=0.1)
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
+ )
+ elif text_cfg.model_type == "roberta":
+ self.text_branch = RobertaModel.from_pretrained('roberta-base')
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape], dropout=0.1)
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
+ )
+ elif text_cfg.model_type == "bart":
+ self.text_branch = BartModel.from_pretrained('facebook/bart-base')
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape], dropout=0.1)
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
+ )
+ else:
+ logging.error(f"Model config for {text_cfg.model_type} not found")
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
+ self.text_branch_type = text_cfg.model_type
+ # text branch parameters
+
+ # audio branch parameters
+ self.audio_transform = MLPLayers(units=[self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape], dropout=0.1)
+
+        # audio projection head and learnable logit scales
+ self.audio_projection = nn.Sequential(
+ nn.Linear(embed_dim, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
+ )
+
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
+
+ self.init_text_branch_parameters()
+
+ def init_text_branch_parameters(self):
+ if self.text_branch_type == "transformer":
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+ proj_std = (self.text_branch.width**-0.5) * (
+ (2 * self.text_branch.layers) ** -0.5
+ )
+ attn_std = self.text_branch.width**-0.5
+ fc_std = (2 * self.text_branch.width) ** -0.5
+ for block in self.text_branch.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
+ width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
+ elif self.text_branch_type == "bart":
+ width = self.text_branch.shared.weight.shape[-1]
+ else:
+ width = self.text_branch.width
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
+
+ # deprecated
+ # if hasattr(self.visual, 'init_parameters'):
+ # self.visual.init_parameters()
+
+ # if self.text_projection is not None:
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
+
+ def build_attention_mask(self):
+        # causal attention mask for the text branch: each token attends only to earlier positions
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ def encode_audio(self, audio, device):
+        return self.audio_branch(audio, mixup_lambda=None, device=device)  # TODO: pass a real mixup_lambda once mixup is supported here
+
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
+ # tmp = {}
+ # for k in x[0].keys():
+ # tmp[k] = []
+ # for i in range(len(x)):
+ # tmp[k].append(x[i][k][:77])
+ # for k in x[0].keys():
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
+ # return tmp
+
+ def encode_text(self, text, device):
+ if self.text_branch_type == "transformer":
+ text = text.to(device=device, non_blocking=True)
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_branch(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
+ elif self.text_branch_type == "bert":
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
+ # text = BatchEncoding(text)
+ x = self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ token_type_ids=text["token_type_ids"].to(
+ device=device, non_blocking=True
+ ),
+ )["pooler_output"]
+ x = self.text_projection(x)
+ elif self.text_branch_type == "roberta":
+ x = self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ )["pooler_output"]
+ x = self.text_projection(x)
+ elif self.text_branch_type == "bart":
+ x = torch.mean(self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ )["encoder_last_hidden_state"],axis=1)
+ x = self.text_projection(x)
+ else:
+ logging.error(f"Model type {self.text_branch_type} not found")
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
+ return x
+
+ def forward(self, audio, text, device=None):
+ """Forward audio and text into the CLAP
+
+ Parameters
+ ----------
+        audio: torch.Tensor (batch_size, audio_length)
+            the time-domain audio input, or the batch of mel_spec / "longer" flags when fusion is enabled
+        text: torch.Tensor or dict of torch.Tensor
+            the tokenized text input; its format depends on the text branch in use
+ """
+ if device is None:
+ if audio is not None:
+ device = audio.device
+ elif text is not None:
+ device = text.device
+ if audio is None and text is None:
+ # a hack to get the logit scale
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+ elif audio is None:
+ return self.encode_text(text, device=device)
+ elif text is None:
+ return self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
+ audio_features = self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
+ audio_features = F.normalize(audio_features, dim=-1)
+
+ text_features = self.encode_text(
+ text, device=device
+ )
+ text_features = F.normalize(text_features, dim=-1)
+
+ audio_features_mlp = self.audio_transform(audio_features)
+ text_features_mlp = self.text_transform(text_features)
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
+ return (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ self.logit_scale_a.exp(),
+ self.logit_scale_t.exp(),
+ )
+
+ def get_logit_scale(self):
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+
+ def get_text_embedding(self, data):
+ """Get the text embedding from the model
+
+ Parameters
+ ----------
+ data: torch.Tensor
+ a tensor of text embedding
+
+ Returns
+ ----------
+ text_embed: torch.Tensor
+ a tensor of text_embeds (N, D)
+
+ """
+ device = next(self.parameters()).device
+ for k in data:
+ data[k] = data[k].to(device)
+ text_embeds = self.encode_text(data, device=device)
+ text_embeds = F.normalize(text_embeds, dim=-1)
+
+ return text_embeds
+
+ def get_audio_embedding(self, data):
+ """Get the audio embedding from the model
+
+ Parameters
+ ----------
+ data: a list of dict
+ the audio input dict list from 'get_audio_feature' method
+
+ Returns
+ ----------
+ audio_embed: torch.Tensor
+ a tensor of audio_embeds (N, D)
+
+ """
+ device = next(self.parameters()).device
+ input_dict = {}
+ keys = data[0].keys()
+ for k in keys:
+ input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)
+ audio_embeds = self.encode_audio(input_dict, device=device)["embedding"]
+ audio_embeds = self.audio_projection(audio_embeds)
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
+ return audio_embeds
+
+
+
+ def audio_infer(self, audio, hopsize=None, device=None):
+ """Forward one audio and produce the audio embedding
+
+ Parameters
+ ----------
+ audio: (audio_length)
+ the time-domain audio input, notice that it must be only one input
+ hopsize: int
+ the overlap hopsize as the sliding window
+
+ Returns
+ ----------
+ output_dict: {
+ key: [n, (embedding_shape)] if "HTS-AT"
+ or
+ key: [(embedding_shape)] if "PANN"
+ }
+ the list of key values of the audio branch
+
+ """
+
+ assert not self.training, "the inference mode must be run at eval stage"
+        output_dict = {}
+        key = "embedding"  # assumed audio-branch output key to return
+ # PANN
+ if self.audio_cfg.model_type == "PANN":
+ audio_input = audio.unsqueeze(dim=0)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
+ elif self.audio_cfg.model_type == "HTSAT":
+ # repeat
+ audio_len = len(audio)
+ k = self.audio_cfg.clip_samples // audio_len
+ if k > 1:
+ audio = audio.repeat(k)
+ audio_len = len(audio)
+
+            if hopsize is None:
+                hopsize = audio_len
+            hopsize = min(hopsize, audio_len)
+
+ if audio_len > self.audio_cfg.clip_samples:
+ audio_input = [
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
+ for pos in range(
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
+ )
+ ]
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
+ audio_input = torch.stack(audio_input)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
+ else:
+ audio_input = audio.unsqueeze(dim=0)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
+
+ return output_dict
+
+
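+# A minimal sketch of how get_text_embedding might be driven once a CLAP instance exists.
+# It assumes a RoBERTa text branch and a matching HuggingFace tokenizer; the tokenizer
+# name, max_length and example prompt are illustrative assumptions.
+def _text_embedding_demo(model, prompts=("a calm piano piece",)):
+    from transformers import RobertaTokenizer
+    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+    tokens = tokenizer(list(prompts), padding="max_length", truncation=True,
+                       max_length=77, return_tensors="pt")
+    # get_text_embedding moves the tensors to the model's device and L2-normalises the output
+    return model.get_text_embedding(dict(tokens))  # shape: (len(prompts), joint_embed_shape)
+
+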
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+ "in_proj_bias",
+ "bias_k",
+ "bias_v",
+ ]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+# Ignore the state dict of the vision part
+def build_model_from_openai_state_dict(state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = 'None'):
+
+ embed_dim = model_cfg["embed_dim"]
+ audio_cfg = model_cfg["audio_cfg"]
+ text_cfg = model_cfg["text_cfg"]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+ if k.startswith(f"transformer.resblocks")
+ )
+ )
+
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
+ text_cfg = CLAPTextCfg(**text_cfg)
+
+ model = CLAP(
+ embed_dim,
+ audio_cfg=audio_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type
+ )
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
+ pop_keys = list(state_dict.keys())[::]
+ # pop the visual branch saved weights
+ for key in pop_keys:
+ if key.startswith("visual."):
+ state_dict.pop(key, None)
+
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+    # fp16 conversion is intentionally skipped here
+ # convert_weights_to_fp16(model)
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device("cpu")):
+ model.eval()
+ audio_length = model.audio_cfg.audio_length
+ example_audio = torch.ones((batch_size, audio_length), device=device)
+ example_text = torch.zeros(
+ (batch_size, model.context_length), dtype=torch.int, device=device
+ )
+ model = torch.jit.trace_module(
+ model,
+ inputs=dict(
+ forward=(example_audio, example_text),
+ encode_text=(example_text,),
+ encode_image=(example_audio,),
+ ),
+ )
+    model.audio_cfg.audio_length = audio_length  # keep the audio_length attribute available on the traced module
+ return model
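+
+
+# A minimal sketch of building a CLAP model directly from a config dict shaped like the
+# JSON files under model_configs/. The literal values mirror the HTSAT-base config; the
+# text_cfg "model_type" is not stored in those JSON files and is assumed here to be
+# "transformer", since CLAPTextCfg requires that field.
+def _build_clap_from_cfg_demo():
+    audio_cfg = {
+        "audio_length": 1024, "clip_samples": 480000, "mel_bins": 64,
+        "sample_rate": 48000, "window_size": 1024, "hop_size": 480,
+        "fmin": 50, "fmax": 14000, "class_num": 527,
+        "model_type": "HTSAT", "model_name": "base",
+    }
+    text_cfg = {
+        "context_length": 77, "vocab_size": 49408, "width": 512,
+        "heads": 8, "layers": 12, "model_type": "transformer",
+    }
+    # CLAP.__init__ accepts plain dicts and converts them into the dataclasses defined above
+    return CLAP(embed_dim=1024, audio_cfg=audio_cfg, text_cfg=text_cfg)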
diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-base.json b/src/laion_clap/clap_module/model_configs/HTSAT-base.json
new file mode 100644
index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/HTSAT-base.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 1024,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "base"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-large.json b/src/laion_clap/clap_module/model_configs/HTSAT-large.json
new file mode 100644
index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/HTSAT-large.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "large"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json b/src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json
new file mode 100644
index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 768,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1536,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "tiny"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-tiny.json b/src/laion_clap/clap_module/model_configs/HTSAT-tiny.json
new file mode 100644
index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/HTSAT-tiny.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 768,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "tiny"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-10.json b/src/laion_clap/clap_module/model_configs/PANN-10.json
new file mode 100644
index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-10.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 1024,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn10"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json
new file mode 100644
index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 18000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json
new file mode 100644
index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 960000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 360,
+ "fmin": 50,
+ "fmax": 8000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json b/src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json
new file mode 100644
index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 4
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json b/src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json
new file mode 100644
index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1536,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-14.json b/src/laion_clap/clap_module/model_configs/PANN-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-14.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/PANN-6.json b/src/laion_clap/clap_module/model_configs/PANN-6.json
new file mode 100644
index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/PANN-6.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 512,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn6"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/RN101-quickgelu.json b/src/laion_clap/clap_module/model_configs/RN101-quickgelu.json
new file mode 100644
index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/RN101-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/RN101.json b/src/laion_clap/clap_module/model_configs/RN101.json
new file mode 100644
index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/RN101.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/RN50-quickgelu.json b/src/laion_clap/clap_module/model_configs/RN50-quickgelu.json
new file mode 100644
index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/RN50-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 1024,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
diff --git a/src/laion_clap/clap_module/model_configs/RN50.json b/src/laion_clap/clap_module/model_configs/RN50.json
new file mode 100644
index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/RN50.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/RN50x16.json b/src/laion_clap/clap_module/model_configs/RN50x16.json
new file mode 100644
index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/RN50x16.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 384,
+ "layers": [
+ 6,
+ 8,
+ 18,
+ 8
+ ],
+ "width": 96,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/RN50x4.json b/src/laion_clap/clap_module/model_configs/RN50x4.json
new file mode 100644
index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/RN50x4.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 288,
+ "layers": [
+ 4,
+ 6,
+ 10,
+ 6
+ ],
+ "width": 80,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/ViT-B-16.json b/src/laion_clap/clap_module/model_configs/ViT-B-16.json
new file mode 100644
index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/ViT-B-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json b/src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json
new file mode 100644
index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/ViT-B-32.json b/src/laion_clap/clap_module/model_configs/ViT-B-32.json
new file mode 100644
index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/model_configs/ViT-L-14.json b/src/laion_clap/clap_module/model_configs/ViT-L-14.json
new file mode 100644
index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241
--- /dev/null
+++ b/src/laion_clap/clap_module/model_configs/ViT-L-14.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/src/laion_clap/clap_module/openai.py b/src/laion_clap/clap_module/openai.py
new file mode 100644
index 0000000000000000000000000000000000000000..9911b6e135e51970177fcac067c12192b0b57c1c
--- /dev/null
+++ b/src/laion_clap/clap_module/openai.py
@@ -0,0 +1,129 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import Union, List
+
+import torch
+
+from .model import build_model_from_openai_state_dict
+from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_tag_models('openai')
+
+
+def load_openai_model(
+ name: str,
+ model_cfg,
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+ jit=True,
+ cache_dir=os.path.expanduser("~/.cache/clip"),
+ enable_fusion: bool = False,
+ fusion_type: str = 'None'
+):
+    """Load an OpenAI CLIP checkpoint, keep its pretrained text branch, and set it in the CLAP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+ device : Union[str, torch.device]
+ The device to put the loaded model
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLAP model
+ """
+ if get_pretrained_url(name, 'openai'):
+ model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir)
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ try:
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device)
+
+ if str(device) == "cpu":
+ model.float()
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_audio)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_audio)
+ patch_float(model.encode_text)
+ model.float()
+
+ model.audio_branch.audio_length = model.audio_cfg.audio_length
+ return model
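+
+
+# A minimal sketch of the intended call pattern. `model_cfg` is assumed to be a dict with
+# the same shape as the JSON files under model_configs/ (embed_dim / audio_cfg / text_cfg),
+# and "ViT-B-16" is only an example of a tag that list_openai_models() may report.
+def _load_openai_demo(model_cfg):
+    print(list_openai_models())  # tags with downloadable OpenAI-pretrained weights
+    model = load_openai_model("ViT-B-16", model_cfg,
+                              device="cpu", jit=False, enable_fusion=False)
+    return model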
diff --git a/src/laion_clap/clap_module/pann_model.py b/src/laion_clap/clap_module/pann_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..109db5f418a0bad32cae2452742589ff52a19b85
--- /dev/null
+++ b/src/laion_clap/clap_module/pann_model.py
@@ -0,0 +1,543 @@
+# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
+# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
+# Some layers are re-designed for CLAP
+import os
+os.environ['NUMBA_CACHE_DIR'] = '/tmp/'
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import do_mixup, interpolate, pad_framewise_output
+from .feature_fusion import iAFF, AFF, DAF
+
+
+def init_layer(layer):
+ """Initialize a Linear or Convolutional layer. """
+ nn.init.xavier_uniform_(layer.weight)
+
+ if hasattr(layer, 'bias'):
+ if layer.bias is not None:
+ layer.bias.data.fill_(0.)
+
+
+def init_bn(bn):
+ """Initialize a Batchnorm layer. """
+ bn.bias.data.fill_(0.)
+ bn.weight.data.fill_(1.)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels):
+
+ super(ConvBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3), stride=(1, 1),
+ padding=(1, 1), bias=False)
+
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3), stride=(1, 1),
+ padding=(1, 1), bias=False)
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_layer(self.conv2)
+ init_bn(self.bn1)
+ init_bn(self.bn2)
+
+
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
+
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ x = F.relu_(self.bn2(self.conv2(x)))
+ if pool_type == 'max':
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == 'avg':
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == 'avg+max':
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+            raise ValueError(f'Unknown pool_type: {pool_type}')
+
+ return x
+
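+# A minimal shape-check sketch for ConvBlock: with the default (2, 2) average pooling the
+# two trailing dimensions are halved. The sizes below are illustrative assumptions for a
+# (batch, channels, time_steps, mel_bins) spectrogram-like input.
+def _conv_block_demo():
+    block = ConvBlock(in_channels=1, out_channels=64)
+    x = torch.randn(2, 1, 100, 64)
+    y = block(x, pool_size=(2, 2), pool_type='avg')
+    return y.shape  # expected: torch.Size([2, 64, 50, 32])
+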
+
+class ConvBlock5x5(nn.Module):
+ def __init__(self, in_channels, out_channels):
+
+ super(ConvBlock5x5, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(5, 5), stride=(1, 1),
+ padding=(2, 2), bias=False)
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_bn(self.bn1)
+
+
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
+
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ if pool_type == 'max':
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == 'avg':
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == 'avg+max':
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+            raise ValueError(f'Unknown pool_type: {pool_type}')
+
+ return x
+
+
+class AttBlock(nn.Module):
+ def __init__(self, n_in, n_out, activation='linear', temperature=1.):
+ super(AttBlock, self).__init__()
+
+ self.activation = activation
+ self.temperature = temperature
+ self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
+ self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True)
+
+ self.bn_att = nn.BatchNorm1d(n_out)
+ self.init_weights()
+
+ def init_weights(self):
+ init_layer(self.att)
+ init_layer(self.cla)
+ init_bn(self.bn_att)
+
+ def forward(self, x):
+ # x: (n_samples, n_in, n_time)
+ # attention weights: softmax over the time axis (logits clamped for numerical stability)
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
+ # per-frame class activations, pooled over time with the attention weights
+ cla = self.nonlinear_transform(self.cla(x))
+ x = torch.sum(norm_att * cla, dim=2)
+ return x, norm_att, cla
+
+ def nonlinear_transform(self, x):
+ if self.activation == 'linear':
+ return x
+ elif self.activation == 'sigmoid':
+ return torch.sigmoid(x)
+
+
+class Cnn14(nn.Module):
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
+ fmax, classes_num, enable_fusion=False, fusion_type='None'):
+
+ super(Cnn14, self).__init__()
+
+ window = 'hann'
+ center = True
+ pad_mode = 'reflect'
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
+ win_length=window_size, window=window, center=center, pad_mode=pad_mode,
+ freeze_parameters=True)
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
+ n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
+ freeze_parameters=True)
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
+ freq_drop_width=8, freq_stripes_num=2)
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ if (self.enable_fusion) and (self.fusion_type == 'channel_map'):
+ self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
+ else:
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
+ self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
+
+ if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']):
+ self.mel_conv1d = nn.Sequential(
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+ nn.BatchNorm1d(64) # No Relu
+ )
+ if self.fusion_type == 'daf_1d':
+ self.fusion_model = DAF()
+ elif self.fusion_type == 'aff_1d':
+ self.fusion_model = AFF(channels=64, type='1D')
+ elif self.fusion_type == 'iaff_1d':
+ self.fusion_model = iAFF(channels=64, type='1D')
+
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
+ self.mel_conv2d = nn.Sequential(
+ nn.Conv2d(1, 64, kernel_size=(5,5), stride=(6, 2), padding=(2,2)),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True)
+ )
+
+ if self.fusion_type == 'daf_2d':
+ self.fusion_model = DAF()
+ elif self.fusion_type == 'aff_2d':
+ self.fusion_model = AFF(channels=64, type='2D')
+ elif self.fusion_type == 'iaff_2d':
+ self.fusion_model = iAFF(channels=64, type='2D')
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+ """
+ Input: (batch_size, data_length)"""
+
+ if self.enable_fusion and input["longer"].sum() == 0:
+ # if no audio is longer than 10s, then randomly select one audio to be longer
+ input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
+
+ if not self.enable_fusion:
+ x = self.spectrogram_extractor(input['waveform'].to(device=device, non_blocking=True)) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ else:
+ longer_list = input["longer"].to(device=device, non_blocking=True)
+ x = input["mel_fusion"].to(device=device, non_blocking=True)
+ longer_list_idx = torch.where(longer_list)[0]
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ if self.fusion_type in ['daf_1d','aff_1d','iaff_1d']:
+ new_x = x[:,0:1,:,:].clone().contiguous()
+ # local processing
+ if len(longer_list_idx) > 0:
+ fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous()
+ FB,FC,FT,FF = fusion_x_local.size()
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous()
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
+ fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1))
+ fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2)
+ if fusion_x_local.size(-1) < FT:
+ fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1)
+ else:
+ fusion_x_local = fusion_x_local[:,:,:FT]
+ # 1D fusion
+ new_x = new_x.squeeze(1).permute((0,2,1)).contiguous()
+ new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local)
+ x = new_x.permute((0,2,1)).contiguous()[:,None,:,:]
+ else:
+ x = new_x
+ elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']:
+ x = x # no change
+
+ if self.training:
+ x = self.spec_augmenter(x)
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+ if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']):
+ global_x = x[:,0:1,:,:]
+
+ # global processing
+ B, C, H, W = global_x.shape
+ global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type='avg')
+ if len(longer_list_idx) > 0:
+ local_x = x[longer_list_idx,1:,:,:].contiguous()
+ TH = global_x.size(-2)
+ # local processing
+ B, C, H, W = local_x.shape
+ local_x = local_x.view(B*C,1,H,W)
+ local_x = self.mel_conv2d(local_x)
+ local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3))
+ local_x = local_x.permute((0,2,1,3,4)).contiguous().flatten(2,3)
+ TB,TC,_,TW = local_x.size()
+ if local_x.size(-2) < TH:
+ local_x = torch.cat([local_x, torch.zeros((TB,TC,TH-local_x.size(-2),TW), device=global_x.device)], dim=-2)
+ else:
+ local_x = local_x[:,:,:TH,:]
+
+ global_x[longer_list_idx] = self.fusion_model(global_x[longer_list_idx],local_x)
+ x = global_x
+ else:
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
+
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 32)
+
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output}
+ return output_dict
+
+
+class Cnn6(nn.Module):
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
+ fmax, classes_num, enable_fusion=False, fusion_type='None'):
+
+ super(Cnn6, self).__init__()
+
+ window = 'hann'
+ center = True
+ pad_mode = 'reflect'
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
+ win_length=window_size, window=window, center=center, pad_mode=pad_mode,
+ freeze_parameters=True)
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
+ n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
+ freeze_parameters=True)
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
+ freq_drop_width=8, freq_stripes_num=2)
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
+
+ self.fc1 = nn.Linear(512, 512, bias=True)
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+ """
+ Input: (batch_size, data_length)"""
+
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 16)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output}
+
+ return output_dict
+
+
+class Cnn10(nn.Module):
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
+ fmax, classes_num, enable_fusion=False, fusion_type='None'):
+
+ super(Cnn10, self).__init__()
+
+ window = 'hann'
+ center = True
+ pad_mode = 'reflect'
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
+ win_length=window_size, window=window, center=center, pad_mode=pad_mode,
+ freeze_parameters=True)
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
+ n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
+ freeze_parameters=True)
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
+ freq_drop_width=8, freq_stripes_num=2)
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+
+ self.fc1 = nn.Linear(1024, 1024, bias=True)
+ self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+ """
+ Input: (batch_size, data_length)"""
+
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 32)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output}
+
+ return output_dict
+
+
+def create_pann_model(audio_cfg, enable_fusion=False, fusion_type='None'):
+ try:
+ ModelProto = eval(audio_cfg.model_name)
+ model = ModelProto(
+ sample_rate = audio_cfg.sample_rate,
+ window_size = audio_cfg.window_size,
+ hop_size =audio_cfg.hop_size,
+ mel_bins = audio_cfg.mel_bins,
+ fmin = audio_cfg.fmin,
+ fmax = audio_cfg.fmax,
+ classes_num = audio_cfg.class_num,
+ enable_fusion = enable_fusion,
+ fusion_type = fusion_type
+ )
+ return model
+ except Exception as e:
+ raise RuntimeError(f'Model {audio_cfg.model_name} not found, or the audio cfg parameters are incomplete.') from e
+
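+# Illustrative usage sketch (not part of the upstream module): `audio_cfg.model_name` selects one
+# of the classes defined above ("Cnn6", "Cnn10", "Cnn14"), e.g.
+#   model = create_pann_model(audio_cfg)          # audio_cfg holds sample_rate, mel_bins, ...
+#   output = model({'waveform': waveform_batch})  # Cnn14 returns 'embedding' and 'clipwise_output'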
diff --git a/src/laion_clap/clap_module/pretrained.py b/src/laion_clap/clap_module/pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..723619a9fd511cf8619def49c4631ec701891b93
--- /dev/null
+++ b/src/laion_clap/clap_module/pretrained.py
@@ -0,0 +1,147 @@
+import hashlib
+import os
+import urllib.request
+import warnings
+
+from tqdm import tqdm
+
+_RN50 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
+)
+
+_RN50_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
+)
+
+_RN101 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
+)
+
+_RN101_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
+)
+
+_RN50x4 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+)
+
+_RN50x16 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+)
+
+_RN50x64 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+)
+
+_VITB32 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB32_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB16 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+)
+
+_VITL14 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+)
+
+_PRETRAINED = {
+ "RN50": _RN50,
+ "RN50-quickgelu": _RN50_quickgelu,
+ "RN101": _RN101,
+ "RN101-quickgelu": _RN101_quickgelu,
+ "RN50x4": _RN50x4,
+ "RN50x16": _RN50x16,
+ "ViT-B-32": _VITB32,
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "ViT-B-16": _VITB16,
+ "ViT-L-14": _VITL14,
+}
+
+
+def list_pretrained(as_str: bool = False):
+ """ returns list of pretrained models
+ Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
+ """
+ return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
+
+
+def list_pretrained_tag_models(tag: str):
+ """ return all models having the specified pretrain tag """
+ models = []
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_model_tags(model: str):
+ """ return all pretrain tags for the specified model architecture """
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def get_pretrained_url(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return ''
+ model_pretrained = _PRETRAINED[model]
+ if tag not in model_pretrained:
+ return ''
+ return model_pretrained[tag]
+
+
+def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ if 'openaipublic' in url:
+ expected_sha256 = url.split("/")[-2]
+ else:
+ expected_sha256 = ''
+
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
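+
+# Usage sketch (illustrative): resolve and fetch a checkpoint with the helpers above; for OpenAI
+# URLs the SHA256 prefix embedded in the URL is verified after download, e.g.
+#   ckpt_path = download_pretrained(get_pretrained_url("RN50", "openai"))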
diff --git a/src/laion_clap/clap_module/timm_model.py b/src/laion_clap/clap_module/timm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..071dd148c772f398e87ecbfc836dcfa4a3ae01af
--- /dev/null
+++ b/src/laion_clap/clap_module/timm_model.py
@@ -0,0 +1,106 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
+except ImportError as e:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """ timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(
+ self,
+ model_name,
+ embed_dim,
+ image_size=224,
+ pool='avg',
+ proj='linear',
+ drop=0.,
+ pretrained=False):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
+ feat_size = self.trunk.default_cfg.get('pool_size', None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ('abs_attn', 'rot_attn'):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool='')
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == 'abs_attn':
+ head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
+ prev_chs = embed_dim
+ elif pool == 'rot_attn':
+ head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, 'projection layer needed if non-attention pooling is used.'
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == 'linear':
+ head_layers['drop'] = nn.Dropout(drop)
+ head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
+ elif proj == 'mlp':
+ head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """ lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError(
+ 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
diff --git a/src/laion_clap/clap_module/tokenizer.py b/src/laion_clap/clap_module/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b4a238b987ce66f2932b11451d916e40816b8a3
--- /dev/null
+++ b/src/laion_clap/clap_module/tokenizer.py
@@ -0,0 +1,180 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ This also avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ if not special_tokens:
+ special_tokens = ['<start_of_text>', '<end_of_text>']
+ else:
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t:t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<start_of_text>"]
+ eot_token = _tokenizer.encoder["<end_of_text>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
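+
+# Usage sketch (illustrative): tokenize() wraps each caption with <start_of_text>/<end_of_text>
+# and zero-pads to the context length, e.g.
+#   tokens = tokenize(["a dog barking", "rain on a tin roof"])   # LongTensor of shape (2, 77)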
diff --git a/src/laion_clap/clap_module/transform.py b/src/laion_clap/clap_module/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..7014c926f153a351d2256c869c67c02d57b30913
--- /dev/null
+++ b/src/laion_clap/clap_module/transform.py
@@ -0,0 +1,30 @@
+from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
+ CenterCrop
+
+
+def _convert_to_rgb(image):
+ return image.convert('RGB')
+
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711)
+):
+ normalize = Normalize(mean=mean, std=std)
+ if is_train:
+ return Compose([
+ RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ])
+ else:
+ return Compose([
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ])
diff --git a/src/laion_clap/clap_module/utils.py b/src/laion_clap/clap_module/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee10b9310944e9de9f8e1db1a4104defd0423744
--- /dev/null
+++ b/src/laion_clap/clap_module/utils.py
@@ -0,0 +1,389 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+import logging
+import h5py
+from tqdm import tqdm
+import random
+import json
+import os
+import pathlib
+
+# TODO: (yusong) this is not a good place to store this information and it does not scale. Needs to be fixed later.
+dataset_split = {
+ "audiocaps": ["train", "valid", "test"],
+ "audioset": ["balanced_train", "unbalanced_train", "eval"],
+ "BBCSoundEffects": ["train", "test"],
+ "Clotho": ["train", "test", "valid"],
+ "free_to_use_sounds": ["train", "test"],
+ "paramount_motion": ["train", "test"],
+ "sonniss_game_effects": ["train", "test"],
+ "wesoundeffects": ["train", "test"],
+ "MACS": ["train", "test"],
+ "freesound": ["train", "test"],
+ "FSD50K": ["train", "test", "valid"],
+ "fsd50k_class_label": ["train", "test", "valid"],
+ "esc50": ["train", "test"],
+ "ESC50_1": ["train", "test"],
+ "ESC50_2": ["train", "test"],
+ "ESC50_3": ["train", "test"],
+ "ESC50_4": ["train", "test"],
+ "ESC50_5": ["train", "test"],
+ "audiostock": ["train", "test"],
+ "freesound_no_overlap_noesc50": ["train", "test"],
+ "epidemic_sound_effects": ["train", "test"],
+ "VGGSound": ["train", "test"],
+ "urbansound8k_class_label": ["train", "test"],
+ "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
+ "audioset_t5_debiased": ["balanced_train", "unbalanced_train", "eval"],
+ "epidemic_sound_effects_t5": ["train", "test"],
+ "epidemic_sound_effects_t5_debiased": ["train", "test"],
+ "WavText5K": ["train", "test"],
+ "esc50_no_overlap": ["train", "test"],
+ "usd8k_no_overlap": ["train", "test"],
+ "fsd50k_200_class_label": ["train", "test", "valid"],
+ "fma_full": ["train", "test"],
+ "Genius": ["train", "test"],
+ "Jamendo": ["train", "test"],
+ "juno": ["train", "test"],
+ "CMU_Arctic": ["train", "test"],
+ "ravdess": ["train", "test"],
+ "Europarl-st": ["train", "test"],
+ "common_voice": ["train", "test"],
+ "Jamendo_16bit": ["train", "test"],
+ "genius_16bit_128": ["train", "test"],
+ "juno_16bit": ["train", "test"],
+ "fma_full_16bit_128": ["train", "test"],
+ "GTZAN": ["train", "test"],
+ }
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=""):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(
+ module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
+ ):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = ".".join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
+
+
+def exist(dataset_name, dataset_type):
+ """
+ Check if dataset exists
+ """
+ if dataset_type in dataset_split[dataset_name]:
+ return True
+ else:
+ return False
+
+
+def get_tar_path_from_dataset_name(
+ dataset_names,
+ dataset_types,
+ islocal,
+ dataset_path,
+ proportion=1,
+ full_dataset=None
+):
+ """
+ Get tar path from dataset name and type
+ """
+ output = []
+ for n in dataset_names:
+ if full_dataset is not None and n in full_dataset:
+ current_dataset_types = dataset_split[n]
+ else:
+ current_dataset_types = dataset_types
+ for s in current_dataset_types:
+ tmp = []
+ if islocal:
+ sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
+ if not os.path.exists(sizefilepath_):
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ else:
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ if not os.path.exists(sizefilepath_):
+ continue
+ sizes = json.load(open(sizefilepath_, "r"))
+ for k in sizes.keys():
+ if islocal:
+ tmp.append(f"{dataset_path}/{n}/{s}/{k}")
+ else:
+ tmp.append(
+ f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
+ )
+ if proportion != 1:
+ tmp = random.sample(tmp, int(proportion * len(tmp)))
+ output.append(tmp)
+ return sum(output, [])
+
+
+def get_tar_path_from_txts(txt_path, islocal, proportion=1):
+ """
+ Get tar path from txt path
+ """
+ if isinstance(txt_path, (list, tuple)):
+ return sum(
+ [
+ get_tar_path_from_txts(
+ txt_path[i], islocal=islocal, proportion=proportion
+ )
+ for i in range(len(txt_path))
+ ],
+ [],
+ )
+ if isinstance(txt_path, str):
+ with open(txt_path) as f:
+ lines = f.readlines()
+ if islocal:
+ lines = [
+ lines[i]
+ .split("\n")[0]
+ .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
+ for i in range(len(lines))
+ ]
+ else:
+ lines = [
+ lines[i].split("\n")[0].replace(".tar", ".tar -")
+ for i in range(len(lines))
+ ]
+ if proportion != 1:
+ print("Sampling tars with proportion of {}".format(proportion))
+ lines = random.sample(lines, int(proportion * len(lines)))
+ return lines
+
+
+def get_mix_lambda(mixup_alpha, batch_size):
+ mixup_lambdas = [
+ np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
+ ]
+ return np.array(mixup_lambdas).astype(np.float32)
+
+
+def do_mixup(x, mixup_lambda):
+ """
+ Args:
+ x: (batch_size , ...)
+ mixup_lambda: (batch_size,)
+ Returns:
+ out: (batch_size, ...)
+ """
+ out = (
+ x.transpose(0, -1) * mixup_lambda
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
+ ).transpose(0, -1)
+ return out
+
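+# Worked example (illustrative): with mixup_lambda = [0.7, 0.3] and a batch of two samples,
+# do_mixup returns out[0] = 0.7 * x[0] + 0.3 * x[1] and out[1] = 0.3 * x[1] + 0.7 * x[0],
+# i.e. each sample is blended with the batch reversed along dim 0.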
+
+def interpolate(x, ratio):
+ """Interpolate data in time domain. This is used to compensate the
+ resolution reduction in downsampling of a CNN.
+
+ Args:
+ x: (batch_size, time_steps, classes_num)
+ ratio: int, ratio to interpolate
+ Returns:
+ upsampled: (batch_size, time_steps * ratio, classes_num)
+ """
+ (batch_size, time_steps, classes_num) = x.shape
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
+ return upsampled
+
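+# Example (illustrative): frame-level predictions of shape (batch, 32, classes) upsampled with
+# ratio=32 become (batch, 1024, classes); each frame is simply repeated 32 times along time.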
+
+def pad_framewise_output(framewise_output, frames_num):
+ """Pad framewise_output to the same length as input frames. The pad value
+ is the same as the value of the last frame.
+ Args:
+ framewise_output: (batch_size, frames_num, classes_num)
+ frames_num: int, number of frames to pad
+ Outputs:
+ output: (batch_size, frames_num, classes_num)
+ """
+ pad = framewise_output[:, -1:, :].repeat(
+ 1, frames_num - framewise_output.shape[1], 1
+ )
+ """tensor for padding"""
+
+ output = torch.cat((framewise_output, pad), dim=1)
+ """(batch_size, frames_num, classes_num)"""
+
+
+def process_ipc(index_path, classes_num, filename):
+ # load data
+ logging.info("Load Data...............")
+ ipc = [[] for _ in range(classes_num)]
+ with h5py.File(index_path, "r") as f:
+ for i in tqdm(range(len(f["target"]))):
+ t_class = np.where(f["target"][i])[0]
+ for t in t_class:
+ ipc[t].append(i)
+ print(ipc)
+ np.save(filename, ipc)
+ logging.info("Load Data Succeed...............")
+
+
+def save_to_dict(s, o_=None):
+ # use None instead of a shared mutable default argument
+ if o_ is None:
+ o_ = {}
+ sp = s.split(": ")
+ o_.update({sp[0]: float(sp[1])})
+ return o_
+
+
+def get_data_from_log(txt_path):
+ """
+ Output dictionary from out.txt log file
+ """
+ with open(txt_path) as f:
+ lines = f.readlines()
+ val_data = {}
+ train_data = {}
+ train_losses = []
+ train_losses_epoch = []
+ for i in range(len(lines)):
+ if "| INFO |" in lines[i]:
+ if "Eval Epoch" in lines[i]:
+ if "val_loss" in lines[i]:
+ # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
+ line = lines[i].split("Eval Epoch: ")[-1]
+ num_epoch = int(line.split(" ")[0].split(" ")[0])
+ d = {
+ line.split(" ")[0]
+ .split(" ")[1]
+ .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
+ }
+ for i in range(1, len(line.split(" "))):
+ d = save_to_dict(line.split(" ")[i], d)
+ val_data[num_epoch] = d
+ elif "Train Epoch" in lines[i]:
+ num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
+ loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
+ train_losses.append(loss)
+ train_losses_epoch.append(num_epoch)
+ for i in range(len(train_losses)):
+ train_data[i] = {
+ "num_epoch": train_losses_epoch[i],
+ "train_loss": train_losses[i],
+ }
+ return train_data, val_data
+
+
+def save_p(obj, filename):
+ import pickle
+
+ try:
+ from deepdiff import DeepDiff
+ except ImportError:
+ os.system("pip install deepdiff")
+ from deepdiff import DeepDiff
+ with open(filename, "wb") as file:
+ pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
+ with open(filename, "rb") as file:
+ z = pickle.load(file)
+ assert (
+ DeepDiff(obj, z, ignore_string_case=True) == {}
+ ), "there is something wrong with the saving process"
+ return
+
+
+def load_p(filename):
+ import pickle
+
+ with open(filename, "rb") as file:
+ z = pickle.load(file)
+ return z
+
+
+def save_json(data, name="data.json"):
+ import json
+ with open(name, 'w') as fp:
+ json.dump(data, fp)
+ return
+
+
+def load_json(name):
+ import json
+ with open(name, 'r') as fp:
+ data = json.load(fp)
+ return data
+
+
+from multiprocessing import Process, Manager
+from multiprocessing import Process, Value, Array
+from ctypes import c_wchar
+
+
+def load_class_label(path):
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+ out = None
+ if path is not None:
+ if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
+ out = load_p(path)
+ elif pathlib.Path(path).suffix in [".json", ".txt"]:
+ out = load_json(path)
+ elif pathlib.Path(path).suffix in [".npy", ".npz"]:
+ out = np.load(path)
+ elif pathlib.Path(path).suffix in [".csv"]:
+ import pandas as pd
+ out = pd.read_csv(path)
+ return out
+ # if out is None:
+ # return None
+ # else:
+ # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
+ # val = Array('i', out.values(), lock=False)
+ # return (key, val)
+
+
+from torch import optim
+
+
+def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
+ if optimizer_name.lower() == "adamw":
+ optimizer = optim.AdamW(
+ params, lr=lr, betas=betas, eps=eps
+ )
+ elif optimizer_name.lower() == "sgd":
+ optimizer = optim.SGD(
+ params, lr=lr, momentum=momentum
+ )
+ elif optimizer_name.lower() == "adam":
+ optimizer = optim.Adam(
+ params, lr=lr, betas=betas, eps=eps
+ )
+ else:
+ raise ValueError("optimizer name is not correct")
+ return optimizer
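+
+# Usage sketch (illustrative): build an AdamW optimizer for a model's parameters, e.g.
+#   opt = get_optimizer(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
+#                       momentum=0.9, optimizer_name="adamw")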
diff --git a/src/laion_clap/clap_module/version.py b/src/laion_clap/clap_module/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc79d63d5430b972ac6ec1c4bfea9af80922da4d
--- /dev/null
+++ b/src/laion_clap/clap_module/version.py
@@ -0,0 +1 @@
+__version__ = '0.2.1'
diff --git a/src/laion_clap/evaluate/__init__.py b/src/laion_clap/evaluate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/laion_clap/evaluate/eval_dcase.py b/src/laion_clap/evaluate/eval_dcase.py
new file mode 100644
index 0000000000000000000000000000000000000000..c615651f2d96f7e34d109e9c3dbb8abc7275065f
--- /dev/null
+++ b/src/laion_clap/evaluate/eval_dcase.py
@@ -0,0 +1,150 @@
+import torch
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from open_clip import create_model
+from open_clip import tokenize
+import glob
+import json
+import librosa
+from tqdm import tqdm
+import numpy as np
+import os
+from laion_clap.training.params import parse_args
+
+
+def get_output_from_single_audio(audio, text, model, device):
+
+ # audio_embedding = model.audio_infer(audio, hopsize=5 * 48000, key="embedding", device=device)['embedding']
+ # if audio_embedding.ndim > 1:
+ # audio_embedding = audio_embedding.mean(dim=0, keepdim=True)
+ # else:
+ # audio_embedding = audio_embedding.unsqueeze(0)
+ audio_features = model(audio, None, device)
+ audio_features = F.normalize(audio_features, dim=-1)
+ text_features = model(None, text, device=device)
+ text_features = F.normalize(text_features, dim=-1)
+
+ # CHANGE: before normalize or after
+ audio_features_mlp = model.audio_transform(audio_features)
+ text_features_mlp = model.text_transform(text_features)
+ return audio_features, text_features, audio_features_mlp, text_features_mlp, model.logit_scale_a.exp(), model.logit_scale_t.exp()
+
+
+def get_metrics(text_to_audio_logits):
+ metrics = {}
+
+ # repeat ground truth 5 times because Clotho has 5 texts for 1 audio
+ # (text_features is the module-level tensor assembled in __main__ below)
+ ground_truth = torch.repeat_interleave(torch.arange(len(text_features) // 5), 5).view(-1, 1)
+
+ ranking = torch.argsort(text_to_audio_logits, descending=True)
+ preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread
+ preds = preds.detach().cpu().numpy()
+ metrics[f"mean_rank"] = preds.mean() + 1
+ metrics[f"median_rank"] = np.floor(np.median(preds)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"R@{k}"] = np.mean(preds < k)
+ # map@10
+ metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
+ return metrics
+
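+# Illustrative note: `preds` holds the 0-based rank of the paired audio for each caption, so
+# R@k is the fraction of captions whose audio ranks in the top k, and mAP@10 credits 1/(rank+1)
+# for ranks below 10 and 0 otherwise.
+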
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ model_path = args.pretrained
+
+ clotho_test_preprocessed_dir = "/fsx/yusong/clotho_test_set/test"
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ audio_features_ensemble_all = []
+ text_features_ensemble_all = []
+ audio_features_mlp_ensemble_all = []
+ text_features_mlp_ensemble_all = []
+ logit_scale_a_ensemble_all = []
+ logit_scale_t_ensemble_all = []
+
+
+ device = torch.device('cuda')
+ model, clap_model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False
+ )
+
+ # load model
+ checkpoint = torch.load(model_path, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module."):]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ # take every 5th file because clotho has 5 texts for 1 audio
+ test_file_list = sorted(glob.glob(f"{clotho_test_preprocessed_dir}/*.flac"))
+
+ audio_features_all = []
+ text_features_all = []
+ audio_features_mlp_all = []
+ text_features_mlp_all = []
+ logit_scale_a_all = []
+ logit_scale_t_all = []
+
+ with torch.no_grad():
+ for file_path in tqdm(test_file_list):
+ json_path = file_path.replace(".flac", ".json")
+ with open(json_path, "r") as f:
+ json_data = json.load(f)
+ audio, sr = librosa.load(file_path, sr=48000, mono=True)
+ audio = torch.from_numpy(audio).to(device)
+ audio = {'waveform': audio.unsqueeze(0), 'sample_rate': sr}
+ text = json_data["text"]
+
+ if args.tmodel == "transformer":
+ from open_clip import tokenize
+ text = tokenize(text)
+ else:
+ from laion_clap.training.data import tokenizer
+ text = tokenizer(text, tmodel=args.tmodel) # 5 texts for each audio
+
+ audio_features, text_features, audio_features_mlp, text_features_mlp, logit_scale_a, logit_scale_t = \
+ get_output_from_single_audio(audio, text, model, device)
+
+ audio_features_all.append(audio_features.detach().cpu())
+ text_features_all.append(text_features.detach().cpu())
+ audio_features_mlp_all.append(audio_features_mlp.detach().cpu())
+ text_features_mlp_all.append(text_features_mlp.detach().cpu())
+ logit_scale_a_all.append(logit_scale_a.detach().cpu())
+ logit_scale_t_all.append(logit_scale_t.detach().cpu())
+
+ audio_features = torch.cat(audio_features_all)
+ text_features = torch.cat(text_features_all)
+ logit_scale_a = logit_scale_a_all[0]
+
+ logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ metrics = get_metrics(
+ logits_per_text
+ )
+
+ print(metrics)
diff --git a/src/laion_clap/evaluate/eval_linear_probe.py b/src/laion_clap/evaluate/eval_linear_probe.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5123451b503de7c54fac58e99b3507439992171
--- /dev/null
+++ b/src/laion_clap/evaluate/eval_linear_probe.py
@@ -0,0 +1,515 @@
+'''
+Evaluate the linear-probe performance on different checkpoints
+'''
+import logging
+import os
+import random
+from datetime import datetime
+import copy
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.cuda.amp import GradScaler
+import glob
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+try:
+ import torch.utils.tensorboard as tensorboard
+except ImportError:
+ tensorboard = None
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from clap_module import create_model_and_transforms, trace_model, create_model
+from training.data import get_data
+from training.params import parse_args
+from training.distributed import is_master, init_distributed_device, world_info_from_env
+from training.logger import setup_logging
+from training.scheduler import cosine_lr
+from training.lp_main import config_lp_optimizer
+from training.lp_train import train_one_epoch, evaluate
+from clap_module.utils import get_tar_path_from_dataset_name, dataset_split
+from clap_module.utils import load_p, load_class_label
+from clap_module.linear_probe import LinearProbe
+
+def maintain_ckpts(args, startidx, all_idx_len):
+ for i in reversed(range(startidx, all_idx_len)):
+ if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
+ os.rename(
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
+ )
+ if os.path.exists(
+ os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
+ ):
+ os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
+ return
+
+
+def update_top_k_performance(
+ new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True, pretrain_epoch=0
+):
+ """
+ Record the top-k performance of the current epoch.
+ current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...}
+ """
+ if isinstance(new_metrics_inputs, (list, tuple)):
+ new_metrics_inputs = np.mean(new_metrics_inputs)
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ pretrain_epoch=pretrain_epoch
+ )
+ elif isinstance(new_metrics_inputs, dict):
+ new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ pretrain_epoch=pretrain_epoch
+ )
+ elif isinstance(new_metrics_inputs, (float, int)):
+ update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
+ sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
+ sorted_values = sorted(
+ current_top_k_ckpt_metrics.values(), reverse=bignumbetter
+ )
+ sorted_values_ = copy.deepcopy(sorted_values)
+ sorted_values.append(new_metrics_inputs)
+ sorted_values = sorted(sorted_values, reverse=bignumbetter)
+ sorted_values = sorted_values[:-1]
+
+ if sorted_values == sorted_values_:
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+ else:
+ for i in range(len(sorted_keys)):
+ if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
+ current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
+ update_flag[sorted_keys[i]] = True
+ for i in range(len(update_flag)):
+ if update_flag[i]:
+ maintain_ckpts(args, i, len(sorted_keys))
+ torch.save(
+ ckpt,
+ os.path.join(args.checkpoint_path, f"pretrain_epoch_{pretrain_epoch}_lp_epoch_top_{i}.pt"),
+ )
+ break
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+
+
+# def updateifNone(a, b):
+# a = b if None else a
+# return a
+
+
+def is_pretrained_params(n):
+ return (
+ n.startswith("clap_model.transformer")
+ or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
+ or n.startswith("clap_model.token_embedding")
+ or n.startswith("clap_model.ln_final")
+ or n.startswith("clap_model.logit_scale_t")
+ )
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
+
+def main():
+ args = parse_args()
+ # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+ args.amodel = args.amodel.replace("/", "-")
+
+ pretrained_ckpts = sorted(glob.glob(os.path.join(args.pretrained, "*.pt")), key=os.path.getmtime)
+
+ if args.name is None:
+ args.name = "-".join(
+ [
+ datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
+ f"linear_probe"
+ f"model_{args.amodel}",
+ f"lr_{args.lr}",
+ f"b_{args.batch_size}",
+ f"j_{args.workers}",
+ f"p_{args.precision}",
+ ]
+ )
+
+ # discover initial world args early so we can log properly
+ args.distributed = False
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+ if args.remotedata and is_master(args):
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+ args.log_path = None
+ if is_master(args, local=args.log_local):
+ log_base_path = os.path.join(args.logs, args.name)
+ os.makedirs(log_base_path, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path, log_filename)
+
+ # avoid a log dir with the same name:
+ postfix = 0
+ while os.path.exists(args.log_path):
+ postfix += 1
+ log_base_path_new = log_base_path+'-'+str(postfix)
+ os.makedirs(log_base_path_new, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path_new, log_filename)
+ # print(
+ # "Error. Experiment already exists. Use --name {} to specify a new experiment."
+ # )
+ # return -1
+
+ # Set logger
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ setup_logging(args.log_path, args.log_level)
+
+ # fully initialize distributed device environment
+ device = init_distributed_device(args)
+
+ args.wandb = "wandb" in args.report_to or "all" in args.report_to
+ args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
+ if is_master(args):
+ args.tensorboard_path = (
+ os.path.join(args.logs, args.name, "tensorboard")
+ if args.tensorboard
+ else ""
+ )
+ args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
+ for dirname in [args.tensorboard_path, args.checkpoint_path]:
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ else:
+ args.tensorboard_path = ""
+ args.checkpoint_path = ""
+
+ if args.copy_codebase:
+ copy_codebase(args)
+
+ assert args.precision in ["amp", "fp16", "fp32"]
+ if args.precision == "fp16":
+ logging.warning(
+ "It is recommended to use AMP mixed-precision instead of FP16. "
+ "FP16 support needs further verification and tuning, especially for train."
+ )
+
+ if args.horovod:
+ logging.info(
+ f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ elif args.distributed:
+ logging.info(
+ f"Running in distributed mode with multiple processes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ else:
+ logging.info(f"Running with a single process. Device {args.device}.")
+
+ logging.info(f'openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}')
+
+ # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
+ args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+ writer = None
+ if args.save_logs and args.tensorboard:
+ assert tensorboard is not None, "Please install tensorboard."
+ writer = tensorboard.SummaryWriter(args.tensorboard_path)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, "Please install wandb."
+ logging.debug("Starting wandb.")
+ # you will have to configure this for your project!
+ wandb.init(
+ project="clap",
+ notes=args.wandb_notes,
+ name=args.wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ logging.debug("Finished loading wandb.")
+
+ for idx, f in enumerate(pretrained_ckpts):
+ logging.info(f"pretrained on {f}")
+ args.pretrained = f
+ ckpt = torch.load(f, map_location='cpu')
+ pretrain_epoch = 0
+ if 'epoch' in ckpt:
+ pretrain_epoch = ckpt['epoch']
+ # train
+ best_metrics = lp_main(args, device, writer, pretrain_epoch, idx)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, "Please install wandb."
+ for name, val in best_metrics.items():
+ wandb.log({f"val/summary/{name}": val, "epoch": pretrain_epoch})
+
+ if args.wandb and is_master(args):
+ wandb.finish()
+
+def update_metric(best_metric, new_metric):
+ for key in new_metric:
+ if key not in best_metric:
+ best_metric[key] = new_metric[key]
+ else:
+ best_metric[key] = max(best_metric[key], new_metric[key])
+ return best_metric
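+# Illustrative example (values are hypothetical): update_metric keeps the best value
+# seen so far for every metric key, e.g.
+#   update_metric({"val_R@1": 0.10}, {"val_R@1": 0.12, "val_R@5": 0.40})
+#   -> {"val_R@1": 0.12, "val_R@5": 0.40}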
+
+def lp_main(args, device, writer, pretrain_epoch, idx):
+
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ args.class_index_dict = load_class_label(args.class_label_path)
+
+
+ # Create CLAP model
+ clap_model, clap_model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ args.lp_out_ch = len(list(args.class_index_dict.keys()))
+ # Linear Probe
+ if idx == 0:
+ logging.info(f"linear probe using mlp: {args.lp_mlp}")
+ logging.info(f"linear probe using freeze: {args.lp_freeze}")
+ logging.info(f"linear probe act layer: {args.lp_act}")
+ logging.info(f"linear probe out ch: {args.lp_out_ch}")
+ logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}")
+ logging.info(f"linear probe loss func: {args.lp_loss}")
+ logging.info(f"linear probe lp_metrics: {args.lp_metrics}")
+
+ model = LinearProbe(
+ clap_model,
+ mlp=args.lp_mlp, freeze=args.lp_freeze,
+ in_ch=512, out_ch=args.lp_out_ch,
+ act=args.lp_act
+ ) # in_ch is fixed (i.e., 512)
+ model = model.to(device)
+
+ if args.horovod:
+ with torch.no_grad():
+ for param in model.parameters():
+ param.set_(param.contiguous())
+
+ if args.trace:
+ model = trace_model(model, batch_size=args.batch_size, device=device)
+
+ if is_master(args) and idx == 0:
+ logging.info("Linear Probe CLAP Model:")
+ logging.info(f"{str(clap_model)}")
+ logging.info("Params:")
+ params_file = os.path.join(args.logs, args.name, "params.txt")
+ with open(params_file, "w") as f:
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ logging.info(f" {name}: {val}")
+ f.write(f"{name}: {val}\n")
+
+
+ if args.distributed and not args.horovod:
+ if args.use_bn_sync:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ ddp_args = {}
+ if args.ddp_static_graph:
+ # this doesn't exist in older PyTorch, arg only added if enabled
+ ddp_args["static_graph"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[device], find_unused_parameters=True, **ddp_args
+ )
+
+ data = get_data(args, clap_model_cfg)
+ assert len(data), "At least one train or eval dataset must be specified."
+ if args.trace:
+ assert "train" not in data, "Cannot train with traced model"
+
+ optimizer, scheduler, text_freeze_parameters = config_lp_optimizer(model, data, args)
+
+ scaler = GradScaler() if args.precision == "amp" else None
+
+ # optionally resume from a checkpoint
+ start_epoch = 0
+ if args.resume is not None:
+ if os.path.isfile(args.resume):
+ checkpoint = torch.load(args.resume, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if not args.distributed and next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module.") :]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ if args.split_opt:
+ if optimizer is not None:
+ for k, o_ in optimizer.items():
+ o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if scaler is not None and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ logging.info(
+ f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ logging.info(
+ f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ else:
+ logging.info("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ if args.wandb and is_master(args):
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ if args.debug:
+ wandb.watch(model, log="all")
+ if idx == 0:
+ wandb.save(params_file)
+
+ best_metrics = {}
+
+ if "train" not in data:
+ metric = evaluate(model, data, start_epoch, args, writer, extra_suffix="_pe@" + str(pretrain_epoch))
+ if is_master(args):
+ best_metrics = update_metric(best_metrics, metric)
+ return
+ elif start_epoch == 0 and "val" in data and not args.no_eval:
+ metric = evaluate(model, data, 0, args, writer, extra_suffix="_pe@" + str(pretrain_epoch))
+ if is_master(args):
+ best_metrics = update_metric(best_metrics, metric)
+ if args.save_top_performance:
+ current_top_k_ckpt_metrics = {
+ i: 0 for i in range(args.save_top_performance)
+ } # initialize the top-k metric for ckpts to 0
+
+ for epoch in range(start_epoch, args.epochs):
+ # freeze the text parameters from epoch args.freeze_text_after (inclusive); the default of -1 disables this
+ if epoch == args.freeze_text_after:
+ print("Text pretrained parameters are freezed since this epoch.")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ if is_master(args):
+ logging.info(f"Start epoch {epoch}")
+
+ train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer, extra_suffix="_pe@" + str(pretrain_epoch))
+ completed_epoch = epoch + 1
+
+ if any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) and not args.no_eval:
+ metric = evaluate(model, data, completed_epoch, args, writer, extra_suffix="_pe@" + str(pretrain_epoch))
+ if is_master(args):
+ best_metrics = update_metric(best_metrics, metric)
+ if args.save_top_performance:
+ top_k_dataset = args.top_k_checkpoint_select_dataset
+ top_k_metric = args.top_k_checkpoint_select_metric
+ filtered_metrics = [
+ v
+ for k, v in metric.items()
+ if top_k_metric in k and top_k_dataset in k
+ ] # check all R@10 metrics (all dataset) and use it to update the ckpt
+ # Saving checkpoints.
+ if args.save_logs:
+ opt_dict = {
+ k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
+ }
+ checkpoint_dict = {
+ "epoch": completed_epoch,
+ "pretrain_epoch": pretrain_epoch,
+ "name": args.name,
+ "state_dict": model.state_dict(),
+ }
+ checkpoint_dict.update(opt_dict)
+ if scaler is not None:
+ checkpoint_dict["scaler"] = scaler.state_dict()
+
+ if completed_epoch == args.epochs or (
+ args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+ ):
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"pretrain_epoch_{pretrain_epoch}_lp_epoch_{completed_epoch}.pt"),
+ )
+ if args.save_most_recent:
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"pretrain_epoch_{pretrain_epoch}_lp_epoch_latest.pt"),
+ )
+ if args.save_top_performance and not args.no_eval:
+ update_top_k_performance(
+ filtered_metrics,
+ current_top_k_ckpt_metrics,
+ args,
+ checkpoint_dict,
+ bignumbetter=True,
+ pretrain_epoch=pretrain_epoch
+ )
+ del clap_model
+ return best_metrics
+
+
+def copy_codebase(args):
+ from shutil import copytree, ignore_patterns
+
+ new_code_path = os.path.join(args.logs, args.name, "code")
+ if os.path.exists(new_code_path):
+ print(
+ f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
+ )
+ return -1
+ print(f"Copying codebase to {new_code_path}")
+ current_code_path = os.path.realpath(__file__)
+ for _ in range(3):
+ current_code_path = os.path.dirname(current_code_path)
+ copytree(
+ current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
+ )
+ print("Done copying code.")
+ return 1
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/src/laion_clap/evaluate/eval_retrieval.py b/src/laion_clap/evaluate/eval_retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..739734cee63588ec647dcb9189520af6294764f2
--- /dev/null
+++ b/src/laion_clap/evaluate/eval_retrieval.py
@@ -0,0 +1,192 @@
+import os.path
+import glob
+import random
+import numpy as np
+import logging
+import wandb
+import torch
+import torch.backends.cudnn as cudnn
+from laion_clap import create_model
+from laion_clap.training.logger import setup_logging
+from laion_clap.training.data import get_data
+from laion_clap.training.train import evaluate
+from laion_clap.utils import get_tar_path_from_dataset_name, dataset_split
+from laion_clap.training.params import parse_args
+
+
+def find_params_value(file, key):
+ # find value of params in params_file
+ with open(file, 'r') as f:
+ for line in f:
+ if key + ': ' in line:
+ return line.split(': ')[1].strip()
+ return None
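+# params.txt is written by the training entry points as plain "name: value" lines
+# (see the params dump in the linear-probe script above), so e.g. a file containing
+#   amodel: HTSAT-tiny
+# yields find_params_value(params_file, 'amodel') == 'HTSAT-tiny'. The value shown
+# here is only an illustration.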
+
+
+if __name__ == '__main__':
+ # (yusong) repeated run might have different metric results.
+ # This is because we randomly select crop 10s for each audio.
+ args = parse_args()
+
+ if os.path.isdir(args.pretrained):
+ log_dir = os.path.dirname(args.pretrained)
+ else:
+ log_dir = os.path.dirname(os.path.dirname(args.pretrained))
+
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ log_path = os.path.join(log_dir, 'out.log')
+ setup_logging(log_path, args.log_level)
+ params_file = os.path.join(log_dir, 'params.txt')
+
+ seed = 3407
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+ pretrained = 'openai'
+ amodel = find_params_value(params_file, 'amodel')
+ tmodel = find_params_value(params_file, 'tmodel')
+
+ if amodel is None or tmodel is None:
+ raise ValueError('model type not found in params file')
+
+ # set up dummy values for args
+ args.parallel_eval = False
+ args.rank = 0
+ args.local_rank = 0
+ args.world_size = 1
+ args.val_frequency = 1
+ args.epochs = 1
+ args.precision = 'fp32'
+ args.save_logs = True
+ args.wandb = True
+ args.class_index_dict = None
+
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ args.device = device
+
+ if args.remotedata:
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ if args.datasetinfos is None:
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+ if args.dataset_type == "webdataset":
+ args.train_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ args.datasetinfos,
+ islocal=not args.remotedata,
+ proportion=args.dataset_proportion,
+ dataset_path=args.datasetpath,
+ )
+ args.val_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ ["valid", "test", "eval"],
+ islocal=not args.remotedata,
+ proportion=1,
+ dataset_path=args.datasetpath,
+ )
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision='fp32',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ ) # a hack to get model_cfg
+
+ data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data
+
+ writer = None  # if using tensorboard, initialize the writer here
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+
+ # # find the line with "wandb_notes" and get the value
+ # wandb_notes = find_params_value(params_file, 'wandb_notes')
+ # if wandb_notes is None:
+ # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.')
+ # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}'
+ # wandb_notes = wandb_notes + '-eval-retrieval'
+ wandb_notes = args.wandb_notes
+
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ if args.wandb_id is not None:
+ wandb.init(
+ project="clap",
+ id=args.wandb_id,
+ resume=True
+ )
+ else:
+ wandb.init(
+ project="clap",
+ notes=wandb_notes,
+ name=wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ logging.debug("Finished loading wandb.")
+
+ if os.path.isdir(args.pretrained):
+ all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime)
+ else:
+ all_model_checkpoints = [args.pretrained]
+ for model_path in all_model_checkpoints:
+ args.checkpoint_path = os.path.dirname(model_path)
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision='fp32',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ # load model
+ checkpoint = torch.load(model_path, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module."):]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ logging.info(
+ f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ start_epoch = 0
+
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ evaluate(model, data, start_epoch, args, writer)
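+# Invocation sketch (the flag spellings below are assumed from the argument names used
+# in this script, and all paths/dataset names are placeholders, not part of this repo):
+#   python src/laion_clap/evaluate/eval_retrieval.py \
+#       --pretrained /path/to/experiment/checkpoints \
+#       --datasetnames <dataset_name> --datasetpath /path/to/webdataset_tar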
diff --git a/src/laion_clap/evaluate/eval_retrieval_main.py b/src/laion_clap/evaluate/eval_retrieval_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..edfa65fdbf19377c1a0ba5c2e8c4fdc6f0d64e96
--- /dev/null
+++ b/src/laion_clap/evaluate/eval_retrieval_main.py
@@ -0,0 +1,257 @@
+import os.path
+import glob
+import random
+import numpy as np
+import logging
+import wandb
+import torch
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from clap_module import create_model
+from clap_module import tokenize
+from training.logger import setup_logging
+from training.data import get_data
+from training.train import evaluate
+from clap_module.utils import get_tar_path_from_dataset_name, dataset_split
+from training.params import parse_args
+
+
+def find_params_value(file, key):
+ # find value of params in params_file
+ with open(file, 'r') as f:
+ for line in f:
+ if key + ': ' in line:
+ return line.split(': ')[1].strip()
+ return None
+
+
+def evaluate_zeroshot(model, data, start_epoch, args, writer):
+ dataloader = data["val"].dataloader
+ metrics = {}
+ device = torch.device(args.device)
+ model.eval()
+ metrics.update({"epoch": start_epoch})
+
+ all_audio_features = []
+ all_class_labels = []
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+ audios = batch  # contains mel_spec, waveform, and longer list
+ audio_features = model(audios, None, device)
+ audio_features = F.normalize(audio_features, dim=-1)
+ all_audio_features.append(audio_features.detach().cpu())
+ all_class_labels.append(torch.argmax(batch["class_label"], 1).long())
+ all_audio_features = torch.cat(all_audio_features, dim=0)
+ all_class_labels = torch.cat(all_class_labels, dim=0)
+ metrics["num_samples"] = all_audio_features.shape[0]
+
+ # get text features
+ all_texts = ["This is a sound of " + t for t in args.class_index_dict.keys()]
+ # (yusong): a hack, can make it better
+ if args.tmodel == "transformer":
+ from clap_module.tokenizer import tokenize
+ all_texts = tokenize(all_texts)
+ else:
+ from training.data import tokenizer
+ all_texts = tokenizer(all_texts)
+ all_text_features = model(None, all_texts, device)
+ all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu()
+
+ # compute similarity
+ logit_scale_a, logit_scale_t = model(None, None, device)
+ logit_scale_a = logit_scale_a.cpu()
+
+ logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu()
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ ground_truth = all_class_labels.view(-1, 1)
+ logit = logits_per_audio
+
+ ranking = torch.argsort(logit, descending=True)
+ preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread
+ preds = preds.detach().cpu().numpy()
+ metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1
+ metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k)
+ # map@10
+ metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
+
+ logging.info(
+ f"Eval Epoch: {start_epoch} "
+ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
+ )
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ for name, val in metrics.items():
+ wandb.log({f"val/{name}": val, "epoch": start_epoch})
+
+
+if __name__ == '__main__':
+ # (yusong) repeated run might have different metric results.
+ # This is because we randomly select crop 10s for each audio.
+ args = parse_args()
+
+ if os.path.isdir(args.pretrained):
+ log_dir = os.path.dirname(args.pretrained)
+ else:
+ log_dir = os.path.dirname(os.path.dirname(args.pretrained))
+
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ log_path = os.path.join(log_dir, 'out.log')
+ setup_logging(log_path, args.log_level)
+ params_file = os.path.join(log_dir, 'params.txt')
+
+ seed = 3407
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+ pretrained = 'openai'
+ amodel = find_params_value(params_file, 'amodel')
+ tmodel = find_params_value(params_file, 'tmodel')
+
+ if amodel is None or tmodel is None:
+ raise ValueError('model type not found in params file')
+
+ # set up dummy values for args
+ args.parallel_eval = False
+ args.rank = 0
+ args.local_rank = 0
+ args.world_size = 1
+ args.val_frequency = 1
+ args.epochs = 1
+ args.precision = 'fp32'
+ args.save_logs = True
+ args.wandb = args.report_to == 'wandb'
+ args.class_index_dict = None
+
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ args.device = device
+
+ if args.remotedata:
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ if args.datasetinfos is None:
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+ if args.dataset_type == "webdataset":
+ args.train_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ args.datasetinfos,
+ islocal=not args.remotedata,
+ proportion=args.dataset_proportion,
+ dataset_path=args.datasetpath,
+ )
+ args.val_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ ["valid", "test", "eval"],
+ islocal=not args.remotedata,
+ proportion=1,
+ dataset_path=args.datasetpath,
+ )
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision='fp32',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ ) # a hack to get model_cfg
+
+ data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data
+
+ writer = None  # if using tensorboard, initialize the writer here
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+
+ # # find the line with "wandb_notes" and get the value
+ # wandb_notes = find_params_value(params_file, 'wandb_notes')
+ # if wandb_notes is None:
+ # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.')
+ # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}'
+ # wandb_notes = wandb_notes + '-eval-retrieval'
+ wandb_notes = args.wandb_notes
+
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ if args.wandb_id is not None:
+ wandb.init(
+ project="clap",
+ id=args.wandb_id,
+ resume=True
+ )
+ else:
+ wandb.init(
+ project="clap",
+ notes=wandb_notes,
+ name=wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ logging.debug("Finished loading wandb.")
+
+ if os.path.isdir(args.pretrained):
+ all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime)
+ else:
+ all_model_checkpoints = [args.pretrained]
+ for model_path in all_model_checkpoints:
+ args.checkpoint_path = os.path.dirname(model_path)
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision='fp32',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ # load model
+ checkpoint = torch.load(model_path, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module."):]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ logging.info(
+ f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ start_epoch = 0
+
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ evaluate_zeroshot(model, data, start_epoch, args, writer)
diff --git a/src/laion_clap/evaluate/eval_zeroshot_classification.py b/src/laion_clap/evaluate/eval_zeroshot_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..577cb91125bacb7f7dc7fb9841c2cd819478a736
--- /dev/null
+++ b/src/laion_clap/evaluate/eval_zeroshot_classification.py
@@ -0,0 +1,261 @@
+import os.path
+import glob
+import random
+import numpy as np
+import logging
+import wandb
+import torch
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+from clap_module import create_model
+from clap_module import tokenize
+from training.logger import setup_logging
+from training.data import get_data
+from training.train import evaluate
+from clap_module.utils import get_tar_path_from_dataset_name, dataset_split
+from training.params import parse_args
+
+
+def find_params_value(file, key):
+ # find value of params in params_file
+ with open(file, 'r') as f:
+ for line in f:
+ if key + ': ' in line:
+ return line.split(': ')[1].strip()
+ return None
+
+
+def evaluate_zeroshot(model, data, start_epoch, args, writer):
+ dataloader = data["val"].dataloader
+ metrics = {}
+ device = torch.device(args.device)
+ model.eval()
+ metrics.update({"epoch": start_epoch})
+
+ all_audio_features = []
+ all_class_labels = []
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+ audios = batch  # contains mel_spec, waveform, and longer list
+ audio_features = model(audios, None, device)
+ audio_features = F.normalize(audio_features, dim=-1)
+ all_audio_features.append(audio_features.detach().cpu())
+ all_class_labels.append(torch.argmax(batch["class_label"], 1).long())
+ all_audio_features = torch.cat(all_audio_features, dim=0)
+ all_class_labels = torch.cat(all_class_labels, dim=0)
+ metrics["num_samples"] = all_audio_features.shape[0]
+
+ # get text features
+ if args.val_dataset_names == ['GTZAN']:
+ all_texts = [f"This is a {t} song." for t in args.class_index_dict.keys()]
+ else:
+ all_texts = [f"This is a sound of {t}." for t in args.class_index_dict.keys()]
+ logging.info(f'class label prompts: {all_texts}')
+ # (yusong): a hack, can make it better
+ if args.tmodel == "transformer":
+ from clap_module.tokenizer import tokenize
+ all_texts = tokenize(all_texts)
+ else:
+ from training.data import tokenizer
+ all_texts = tokenizer(all_texts)
+ all_text_features = model(None, all_texts, device)
+ all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu()
+
+ # compute similarity
+ logit_scale_a, logit_scale_t = model(None, None, device)
+ logit_scale_a = logit_scale_a.cpu()
+
+ logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu()
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ ground_truth = all_class_labels.view(-1, 1)
+ logit = logits_per_audio
+
+ ranking = torch.argsort(logit, descending=True)
+ preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread
+ preds = preds.detach().cpu().numpy()
+ metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1
+ metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k)
+ # map@10
+ metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
+
+ logging.info(
+ f"Eval Epoch: {start_epoch} "
+ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
+ )
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ for name, val in metrics.items():
+ wandb.log({f"val/{name}": val, "epoch": start_epoch})
+
+
+if __name__ == '__main__':
+ # (yusong) repeated run might have different metric results.
+ # This is because we randomly select crop 10s for each audio.
+ args = parse_args()
+
+ if os.path.isdir(args.pretrained):
+ log_dir = os.path.dirname(args.pretrained)
+ else:
+ log_dir = os.path.dirname(os.path.dirname(args.pretrained))
+
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ log_path = os.path.join(log_dir, 'out.log')
+ setup_logging(log_path, args.log_level)
+ params_file = os.path.join(log_dir, 'params.txt')
+
+ seed = 3407
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+ pretrained = 'openai'
+ amodel = find_params_value(params_file, 'amodel')
+ tmodel = find_params_value(params_file, 'tmodel')
+
+ if amodel is None or tmodel is None:
+ raise ValueError('model type not found in params file')
+
+ # set up dummy values for args
+ args.parallel_eval = False
+ args.rank = 0
+ args.local_rank = 0
+ args.world_size = 1
+ args.val_frequency = 1
+ args.epochs = 1
+ args.precision = 'fp32'
+ args.save_logs = True
+ args.wandb = args.report_to == 'wandb'
+ args.class_index_dict = None
+
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+ args.device = device
+
+ if args.remotedata:
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ if args.datasetinfos is None:
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+ if args.dataset_type == "webdataset":
+ args.train_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ args.datasetinfos,
+ islocal=not args.remotedata,
+ proportion=args.dataset_proportion,
+ dataset_path=args.datasetpath,
+ )
+ args.val_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ ["valid", "test", "eval"],
+ islocal=not args.remotedata,
+ proportion=1,
+ dataset_path=args.datasetpath,
+ )
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision='fp32',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ ) # a hack to get model_cfg
+
+ data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data
+
+ writer = None  # if using tensorboard, initialize the writer here
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+
+ # # find the line with "wandb_notes" and get the value
+ # wandb_notes = find_params_value(params_file, 'wandb_notes')
+ # if wandb_notes is None:
+ # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.')
+ # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}'
+ # wandb_notes = wandb_notes + '-eval-retrieval'
+ wandb_notes = args.wandb_notes
+
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ if args.wandb_id is not None:
+ wandb.init(
+ project="clap",
+ id=args.wandb_id,
+ resume=True
+ )
+ else:
+ wandb.init(
+ project="clap",
+ notes=wandb_notes,
+ name=wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ logging.debug("Finished loading wandb.")
+
+ if os.path.isdir(args.pretrained):
+ all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime)
+ else:
+ all_model_checkpoints = [args.pretrained]
+ for model_path in all_model_checkpoints:
+ args.checkpoint_path = os.path.dirname(model_path)
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision='fp32',
+ device=device,
+ jit=False,
+ force_quick_gelu=False,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ # load model
+ checkpoint = torch.load(model_path, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module."):]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ logging.info(
+ f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ start_epoch = 0
+
+ model.to(device)
+ model.eval()
+ for param in model.parameters():
+ param.requires_grad = False
+
+ evaluate_zeroshot(model, data, start_epoch, args, writer)
diff --git a/src/laion_clap/hook.py b/src/laion_clap/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86e942eaf44dfbe0d598bf789e21950e9974778
--- /dev/null
+++ b/src/laion_clap/hook.py
@@ -0,0 +1,219 @@
+"""
+Contrastive Language-Audio Pretraining Model from LAION
+--------------------------------------------------------
+Paper: https://arxiv.org/abs/2211.06687
+Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui
+Support: LAION
+"""
+import os
+import torch
+import librosa
+from clap_module import create_model
+from training.data import get_audio_features
+from training.data import int16_to_float32, float32_to_int16
+
+from transformers import RobertaTokenizer
+import wget
+from clap_module.factory import load_state_dict
+
+
+class CLAP_Module(torch.nn.Module):
+ def __init__(self, enable_fusion=False, device=None, amodel='HTSAT-tiny', tmodel='roberta') -> None:
+ """Initialize CLAP Model
+
+ Parameters
+ ----------
+ enable_fusion: bool
+ if true, it will create the fusion clap model, otherwise non-fusion clap model (default: false)
+ device: str
+ if None, it will automatically detect the device (gpu or cpu)
+ amodel: str
+ audio encoder architecture, default: HTSAT-tiny
+ tmodel: str
+ text encoder architecture, default: roberta
+ """
+ super(CLAP_Module, self).__init__()
+ if device is None:
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+
+ precision = 'fp32'
+
+ if enable_fusion:
+ fusion_type = 'aff_2d'
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ precision=precision,
+ device=device,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type
+ )
+ else:
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ precision=precision,
+ device=device,
+ enable_fusion=enable_fusion
+ )
+ self.enable_fusion = enable_fusion
+ self.model = model
+ self.model_cfg = model_cfg
+ self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
+
+ def tokenizer(self, text):
+ result = self.tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ return result
+
+ def load_ckpt(self, ckpt = None, model_id = -1, verbose = True):
+ """Load the pretrained checkpoint of CLAP model
+
+ Parameters
+ ----------
+ ckpt: str
+ if ckpt is specified, the model will load this ckpt, otherwise the model will download the ckpt from Hugging Face. \n
+ For fusion model, it will download the 630k+audioset fusion model (id=3). For non-fusion model, it will download the 630k+audioset model (id=1).
+ model_id:
+ if model_id is specified, you can download one of our best ckpts, as:
+ id = 0 --> 630k non-fusion ckpt \n
+ id = 1 --> 630k+audioset non-fusion ckpt \n
+ id = 2 --> 630k fusion ckpt \n
+ id = 3 --> 630k+audioset fusion ckpt \n
+ Note that if your model is specified as a non-fusion model but you download a fusion model ckpt, you will face an error.
+ """
+ download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/'
+ download_names = [
+ '630k-best.pt',
+ '630k-audioset-best.pt',
+ '630k-fusion-best.pt',
+ '630k-audioset-fusion-best.pt'
+ ]
+ if ckpt is not None:
+ print(f'Loading the user-specified checkpoint {ckpt}.')
+ else:
+ print('Loading the best checkpoint reported in the paper.')
+ if model_id == -1:
+ model_id = 3 if self.enable_fusion else 1
+ package_dir = os.path.dirname(os.path.realpath(__file__))
+ weight_file_name = download_names[model_id]
+ ckpt = os.path.join(package_dir, weight_file_name)
+ if os.path.exists(ckpt):
+ print('The checkpoint is already downloaded.')
+ else:
+ print('Downloading laion_clap weight files...')
+ ckpt = wget.download(download_link + weight_file_name, os.path.dirname(ckpt))
+ print('Download completed!')
+ print('Load Checkpoint...')
+ ckpt = load_state_dict(ckpt, skip_params=True)
+ self.model.load_state_dict(ckpt)
+ if verbose:
+ param_names = [n for n, p in self.model.named_parameters()]
+ for n in param_names:
+ print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
+
+ def get_audio_embedding_from_filelist(self, x, use_tensor=False):
+ """get audio embeddings from the audio file list
+
+ Parameters
+ ----------
+ x: List[str] (N,):
+ an audio file list to extract features from; audio files can have different lengths (as we have the feature fusion mechanism)
+ use_tensor: boolean:
+ if True, it will return the torch tensor, preserving the gradient (default: False).
+ Returns
+ ----------
+ audio_embed : numpy.ndarray | torch.Tensor (N,D):
+ audio embeddings extracted from the audio files
+ """
+ self.model.eval()
+ audio_input = []
+ for f in x:
+ # load the waveform of the shape (T,), should resample to 48000
+ audio_waveform, _ = librosa.load(f, sr=48000)
+ # quantize
+ audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
+ audio_waveform = torch.from_numpy(audio_waveform).float()
+ temp_dict = {}
+ temp_dict = get_audio_features(
+ temp_dict, audio_waveform, 480000,
+ data_truncating='fusion' if self.enable_fusion else 'rand_trunc',
+ data_filling='repeatpad',
+ audio_cfg=self.model_cfg['audio_cfg'],
+ require_grad=audio_waveform.requires_grad
+ )
+ audio_input.append(temp_dict)
+ audio_embed = self.model.get_audio_embedding(audio_input)
+ if not use_tensor:
+ audio_embed = audio_embed.detach().cpu().numpy()
+ return audio_embed
+
+
+ def get_audio_embedding_from_data(self, x, use_tensor=False):
+ """get audio embeddings from the audio data
+
+ Parameters
+ ----------
+ x: np.ndarray | torch.Tensor (N,T):
+ audio data, must be mono audio tracks.
+ use_tensor: boolean:
+ if True, x should be a tensor input and the output will be a tensor, preserving the gradient (default: False).
+ Note that if 'use_tensor' is set to True, the audio waveform is not quantized (otherwise the gradient would not be preserved).
+ Returns
+ ----------
+ audio_embed : numpy.ndarray | torch.Tensor (N,D):
+ audio embeddings extracted from the audio data
+ """
+ self.model.eval()
+ audio_input = []
+ for audio_waveform in x:
+ # quantize
+ if not use_tensor:
+ audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
+ audio_waveform = torch.from_numpy(audio_waveform).float()
+ temp_dict = {}
+ temp_dict = get_audio_features(
+ temp_dict, audio_waveform, 480000,
+ data_truncating='fusion' if self.enable_fusion else 'rand_trunc',
+ data_filling='repeatpad',
+ audio_cfg=self.model_cfg['audio_cfg'],
+ require_grad=audio_waveform.requires_grad
+ )
+ audio_input.append(temp_dict)
+ audio_embed = self.model.get_audio_embedding(audio_input)
+ if not use_tensor:
+ audio_embed = audio_embed.detach().cpu().numpy()
+ return audio_embed
+
+ def get_text_embedding(self, x, tokenizer = None, use_tensor = False):
+ """get text embeddings from texts
+
+ Parameters
+ ----------
+ x: List[str] (N,):
+ text list
+ tokenizer: func:
+ the tokenizer function, if not provided (None), will use the default Roberta tokenizer.
+ use_tensor: boolean:
+ if True, the output will be a tensor, preserving the gradient (default: False).
+ Returns
+ ----------
+ text_embed : numpy.ndarray | torch.Tensor (N,D):
+ text embeddings extracted from the texts
+ """
+ self.model.eval()
+ if tokenizer is not None:
+ text_input = tokenizer(x)
+ else:
+ text_input = self.tokenizer(x)
+ text_embed = self.model.get_text_embedding(text_input)
+ if not use_tensor:
+ text_embed = text_embed.detach().cpu().numpy()
+ return text_embed
+
+
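+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only; the text prompts are arbitrary and the
+    # first call downloads the default checkpoint, so network access is required).
+    clap = CLAP_Module(enable_fusion=False)
+    clap.load_ckpt()  # with the defaults this resolves to '630k-audioset-best.pt'
+    text_embed = clap.get_text_embedding(["a piano melody", "heavy rain on a window"])
+    print(text_embed.shape)  # expected: (2, 512) for the released checkpoints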
diff --git a/src/laion_clap/inference.py b/src/laion_clap/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..759c8e32ab9cbd4925de0b0aaad10766f010da56
--- /dev/null
+++ b/src/laion_clap/inference.py
@@ -0,0 +1,41 @@
+import numpy as np
+import librosa
+import torch
+from src import laion_clap
+from glob import glob
+import pandas as pd
+from ..config.configs import ProjectPaths
+import pickle
+
+
+class AudioEncoder(laion_clap.CLAP_Module):
+ def __init__(self) -> None:
+ super().__init__(enable_fusion=False, amodel='HTSAT-base')
+ self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH)
+
+ def extract_audio_representaion(self, file_name):
+ audio_data, _ = librosa.load(file_name, sr=48000)
+ audio_data = audio_data.reshape(1, -1)
+ with torch.no_grad():
+ audio_embed = self.get_audio_embedding_from_data(x=audio_data, use_tensor=False)
+ return audio_embed
+
+ def extract_bulk_audio_representaions(self, save=False):
+ music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav")))
+ song_names = [k.split("/")[-1] for k in music_files]
+ music_data = np.zeros((len(music_files), 512), dtype=np.float32)
+ for m in range(music_data.shape[0]):
+ music_data[m] = self.extract_audio_representaion(music_files[m])
+
+ if not save:
+ return music_data, song_names
+
+ else:
+ np.save(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy"), music_data)
+ with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl"), "wb") as writer:
+ pickle.dump(song_names, writer)
+
+ def extract_text_representation(self, text):
+ text_data = [text]
+ text_embed = self.get_text_embedding(text_data)
+ return text_embed
\ No newline at end of file
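+# Minimal usage sketch (run from the project root so the package-relative imports
+# resolve; the query text and the .wav path below are placeholders):
+#   from src.laion_clap.inference import AudioEncoder
+#   encoder = AudioEncoder()
+#   text_vec = encoder.extract_text_representation("calm acoustic guitar")      # (1, 512)
+#   audio_vec = encoder.extract_audio_representaion("data/audio/example.wav")   # (1, 512)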
diff --git a/src/laion_clap/training/__init__.py b/src/laion_clap/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/laion_clap/training/audioset_textmap.npy b/src/laion_clap/training/audioset_textmap.npy
new file mode 100644
index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b
--- /dev/null
+++ b/src/laion_clap/training/audioset_textmap.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
+size 84448
diff --git a/src/laion_clap/training/data.py b/src/laion_clap/training/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad90621575f45388feb5c015cfff5d1f5fe2146
--- /dev/null
+++ b/src/laion_clap/training/data.py
@@ -0,0 +1,895 @@
+import ast
+import json
+import logging
+import math
+import os
+import random
+import h5py
+from dataclasses import dataclass
+import braceexpand
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+import torchvision.datasets as datasets
+import torchvision.transforms
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
+from torch.utils.data.distributed import DistributedSampler
+from functools import partial
+from pathlib import Path
+import wget
+import tempfile
+import copy
+from contextlib import suppress
+
+from clap_module.utils import get_tar_path_from_dataset_name, dataset_split
+from clap_module.utils import load_p, load_class_label
+from clap_module import tokenize as clip_tokenizer
+from transformers import BertTokenizer
+from transformers import RobertaTokenizer
+from transformers import BartTokenizer
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+try:
+ import torchaudio
+except ImportError:
+ torchaudio = None
+
+bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+
+def tokenizer(text, tmodel="roberta", max_length=77):
+ """tokenizer for different models
+ tmodel defaults to roberta, as it is the best model for our task
+ max_length defaults to 77, following the OpenAI CLIP parameters
+ We assume text to be a single string, but it can also be a list of strings
+ """
+ if tmodel == "transformer":
+ return clip_tokenizer(text).squeeze(0)
+
+ elif tmodel == "bert":
+ result = bert_tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
+ elif tmodel == "roberta":
+ result = roberta_tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
+ elif tmodel == "bart":
+ result = bart_tokenizer(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
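+# Example (roberta, the default): tokenizer("a dog barking") returns a dict with
+# 'input_ids' and 'attention_mask', each a tensor of shape (77,) after the squeeze;
+# the example caption is arbitrary.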
+
+# initialize the audioset map
+_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
+_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
+
+
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1., a_max=1.)
+ return (x * 32767.).astype(np.int16)
+
+
+def int16_to_float32_torch(x):
+ return (x / 32767.0).type(torch.float32)
+
+
+def float32_to_int16_torch(x):
+ x = torch.clamp(x, min=-1., max=1.)
+ return (x * 32767.).type(torch.int16)
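+# The round trip int16_to_float32(float32_to_int16(x)) used elsewhere in this repo
+# (e.g. hook.py) clips the waveform to [-1, 1] and rounds it to 16-bit resolution,
+# e.g. 0.123456 -> 4045 -> ~0.12345 (illustrative values).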
+
+
+# For Toy Dataset
+class ToyDataset(Dataset):
+ def __init__(self, index_path, ipc, config, eval_mode=False):
+ """Toy Dataset for testing the audioset input with text labels
+ Parameters
+ ----------
+ index_path: str
+ the path to the h5 file of each audio
+ ipc: str
+ the path to the npy file holding the number of samples in each class
+ config: dict
+ the audio cfg file
+ eval_mode (bool): whether the dataset is used for evaluation/testing
+ """
+ self.audio_cfg = config["audio_cfg"]
+ self.text_cfg = config["text_cfg"]
+ self.fp = h5py.File(index_path, "r")
+ self.ipc = np.load(ipc, allow_pickle=True)
+ self.total_size = len(self.fp["audio_name"])
+ self.classes_num = self.audio_cfg["class_num"]
+ self.eval_mode = eval_mode
+
+ if not eval_mode:
+ self.generate_queue()
+ else:
+ self.queue = []
+ for i in range(self.total_size):
+ target = self.fp["target"][i]
+ if np.sum(target) > 0:
+ self.queue.append(i)
+ self.total_size = len(self.queue)
+ logging.info("total dataset size: %d" % (self.total_size))
+ logging.info("class num: %d" % (self.classes_num))
+
+ def time_shifting(self, x):
+ frame_num = len(x)
+ shift_len = random.randint(0, frame_num - 1)
+ new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
+ return new_sample
+
+ def generate_queue(self):
+ self.queue = []
+ while len(self.queue) < self.total_size:
+ class_set = [*range(self.classes_num)]
+ random.shuffle(class_set)
+ self.queue += [
+ self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
+ ]
+ self.queue = self.queue[: self.total_size]
+
+ logging.info("queue regenerated:%s" % (self.queue[-5:]))
+
+ def crop_wav(self, x):
+ crop_size = self.audio_cfg["crop_size"]
+ crop_pos = random.randint(0, len(x) - crop_size - 1)
+ return x[crop_pos: crop_pos + crop_size]
+
+ def prompt_text(self, target):
+ events = _AUDIOSET_MAP[np.where(target > 0)]
+ event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
+ text = tokenizer(event_text)[0]
+ return text
+
+ def __getitem__(self, index):
+ """Load waveform, text, and target of an audio clip
+
+ Parameters
+ ----------
+ index: int
+ the index number
+ Return
+ ------
+ output: dict {
+ "hdf5_path": str,
+ "index_in_hdf5": int,
+ "audio_name": str,
+ "waveform": list (audio_length,),
+ "target": list (class_num, ),
+ "text": torch.tensor (context_length,)
+ }
+ the output dictionary
+ """
+ s_index = self.queue[index]
+
+ audio_name = self.fp["audio_name"][s_index].decode()
+ # Hardcode here CHANGE
+ hdf5_path = (
+ self.fp["hdf5_path"][s_index]
+ .decode()
+ .replace(
+ "../workspace",
+ "/home/la/kechen/Research/ke_zsasp/workspace",
+ )
+ )
+ r_idx = self.fp["index_in_hdf5"][s_index]
+ target = self.fp["target"][s_index].astype(np.float32)
+ text = self.prompt_text(target)
+ with h5py.File(hdf5_path, "r") as f:
+ waveform = int16_to_float32(f["waveform"][r_idx])[
+ : self.audio_cfg["clip_samples"]
+ ]
+ assert (
+ len(waveform) == self.audio_cfg["clip_samples"]
+ ), "The sample length is not match"
+ # Time shift
+ # if (self.config.enable_time_shift) and (not self.eval_mode):
+ # waveform = self.time_shifting(waveform)
+ # # Label Enhance
+ # if (self.config.crop_size is not None) and (not self.eval_mode):
+ # waveform = self.crop_wav(waveform)
+ # # the label enhance rate is fixed 0.5
+ # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
+ # kidx = np.where(target)[0]
+ # for k in kidx:
+ # for add_key in self.class_map[k][1]:
+ # target[add_key] = 1.0
+ # if len(self.class_map[k][2]) > 0:
+ # add_key = random.choice(self.class_map[k][2])
+ # target[add_key] = 1.0
+
+ # missing the text input
+ mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
+ mel_spec = torch.cat([mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0).cpu().numpy()
+ longer = random.choice([True, False])
+ if not longer:
+ mel_spec[1:, :, :] = 0.0
+ data_dict = {
+ "hdf5_path": hdf5_path,
+ "index_in_hdf5": r_idx,
+ "audio_name": audio_name,
+ "waveform": waveform,
+ "class_label": target,
+ "text": text,
+ "longer": longer,
+ "mel_fusion": mel_spec
+ }
+ return data_dict
+
+ def __len__(self):
+ return self.total_size
+
+@dataclass
+class DataInfo:
+ dataloader: DataLoader
+ sampler: DistributedSampler
+
+
+def get_dataset_size(shards, sizefilepath_=None, is_local=True):
+ if isinstance(shards, list):
+ size_list = []
+ for s in shards:
+ size_list.append(
+ get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
+ )
+ else:
+ if not is_local:
+ for n in dataset_split.keys():
+ if n in shards.split("/"):
+ break
+ for s in dataset_split[n]:
+ if s in shards.split("/"):
+ break
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ shards_list = list(braceexpand.braceexpand(shards))
+ dir_path = os.path.dirname(shards)
+ if sizefilepath_ is not None:
+ sizes = json.load(open(sizefilepath_, "r"))
+ total_size = sum(
+ [
+ int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
+ for shard in shards_list
+ ]
+ )
+ else:
+ sizes_filename = os.path.join(dir_path, "sizes.json")
+ len_filename = os.path.join(dir_path, "__len__")
+ if os.path.exists(sizes_filename):
+ sizes = json.load(open(sizes_filename, "r"))
+ total_size = sum(
+ [int(sizes[os.path.basename(shard)]) for shard in shards_list]
+ )
+ elif os.path.exists(len_filename):
+ # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+ total_size = ast.literal_eval(open(len_filename, "r").read())
+ else:
+ raise Exception(
+ f"Cannot find sizes file for dataset {shards}. Please specify the path to the file."
+ )
+ # total_size = None # num samples undefined
+ # some common dataset sizes (at time of authors last download)
+ # cc3m-train: 2905954
+ # cc12m: 10968539
+ # LAION-400m: 407332084
+ num_shards = len(shards_list)
+ if isinstance(shards, list):
+ return sum(size_list), len(shards)
+ else:
+ return total_size, num_shards
+
+
+def count_samples(dataloader):
+ os.environ["WDS_EPOCH"] = "0"
+ n_elements, n_batches = 0, 0
+ for images, texts in dataloader:
+ n_batches += 1
+ n_elements += len(images)
+ assert len(images) == len(texts)
+ return n_elements, n_batches
+
+
+def log_and_continue(exn):
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+ return True
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+def sample_prop(sizefile, inputs, proportion, is_local=True):
+ """
+ Sample a proportion of the data.
+ """
+ file_path_dict = {
+ os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
+ for i in range(len(inputs))
+ }
+ sampled_filepath_dict = {}
+ sampled_size_dict = {}
+ if not is_local:
+ if os.path.exists("sizes.json"):
+ os.remove("sizes.json")
+ wget.download(sizefile, "sizes.json")
+ sizefile = "sizes.json"
+ with open(sizefile, "r", encoding="UTF-8") as f:
+ load_dict = json.load(f)
+ L = int(len(file_path_dict) * proportion)
+ subkeys = random.sample(list(file_path_dict.keys()), L)  # list() needed: random.sample requires a sequence
+ for k in subkeys:
+ sampled_size_dict[k] = load_dict[k]
+ sampled_filepath_dict[k] = file_path_dict[k]
+ return (
+ sum(sampled_size_dict.values()),
+ L,
+ [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
+ sampled_size_dict,
+ )
+
+
+def get_mel(audio_data, audio_cfg):
+ # mel shape: (n_mels, T)
+ mel_tf = torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_cfg['sample_rate'],
+ n_fft=audio_cfg['window_size'],
+ win_length=audio_cfg['window_size'],
+ hop_length=audio_cfg['hop_size'],
+ center=True,
+ pad_mode="reflect",
+ power=2.0,
+ norm=None,
+ onesided=True,
+ n_mels=audio_cfg['mel_bins'],
+ f_min=audio_cfg['fmin'],
+ f_max=audio_cfg['fmax']
+ ).to(audio_data.device)
+
+ mel = mel_tf(audio_data)
+ # Align to librosa:
+ # librosa_melspec = librosa.feature.melspectrogram(
+ # waveform,
+ # sr=audio_cfg['sample_rate'],
+ # n_fft=audio_cfg['window_size'],
+ # hop_length=audio_cfg['hop_size'],
+ # win_length=audio_cfg['window_size'],
+ # center=True,
+ # pad_mode="reflect",
+ # power=2.0,
+ # n_mels=audio_cfg['mel_bins'],
+ # norm=None,
+ # htk=True,
+ # f_min=audio_cfg['fmin'],
+ # f_max=audio_cfg['fmax']
+ # )
+ # we use log mel spectrogram as input
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
+ return mel.T # (T, n_mels)
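+# Shape example (audio_cfg values below are illustrative, e.g. sample_rate=48000,
+# hop_size=480, mel_bins=64): a 10 s clip has 480000 samples, and with center=True
+# the returned log-mel spectrogram is (1 + 480000 // 480, 64) = (1001, 64), i.e. (T, n_mels).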
+
+
+def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg, require_grad=False):
+ """
+ Calculate and add audio features to sample.
+ Sample: a dict containing all the data of current sample.
+ audio_data: a tensor of shape (T) containing audio data.
+ max_len: the maximum length of audio data.
+ data_truncating: the method of truncating data.
+ data_filling: the method of filling data.
+ audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
+ require_grad: whether to require gradient for audio data.
+ This is useful when we want to apply gradient-based classifier-guidance.
+ """
+ grad_fn = suppress if require_grad else torch.no_grad
+ with grad_fn():
+ if len(audio_data) > max_len:
+ if data_truncating == "rand_trunc":
+ longer = torch.tensor([True])
+ elif data_truncating == "fusion":
+ # fusion
+ mel = get_mel(audio_data, audio_cfg)
+ # split to three parts
+ chunk_frames = max_len // audio_cfg['hop_size'] + 1 # the +1 related to how the spectrogram is computed
+ total_frames = mel.shape[0]
+ if chunk_frames == total_frames:
+ # there is a corner case where the audio length is
+ # larger than max_len but smaller than max_len+hop_size.
+ # In this case, we just use the whole audio.
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
+ sample["mel_fusion"] = mel_fusion
+ longer = torch.tensor([False])
+ else:
+ ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
+ # print('total_frames-chunk_frames:', total_frames-chunk_frames,
+ # 'len(audio_data):', len(audio_data),
+ # 'chunk_frames:', chunk_frames,
+ # 'total_frames:', total_frames)
+ if len(ranges[1]) == 0:
+ # if the audio is too short, we just use the first chunk
+ ranges[1] = [0]
+ if len(ranges[2]) == 0:
+ # if the audio is too short, we just use the first chunk
+ ranges[2] = [0]
+ # randomly choose index for each part
+ idx_front = np.random.choice(ranges[0])
+ idx_middle = np.random.choice(ranges[1])
+ idx_back = np.random.choice(ranges[2])
+ # select mel
+ mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
+ mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
+ mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
+
+ # shrink the mel
+ mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, audio_cfg['mel_bins']])(mel[None])[0]
+ # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
+
+ # stack
+ mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
+ sample["mel_fusion"] = mel_fusion
+ longer = torch.tensor([True])
+ else:
+ raise NotImplementedError(
+ f"data_truncating {data_truncating} not implemented"
+ )
+ # random crop to max_len (for compatibility)
+ overflow = len(audio_data) - max_len
+ idx = np.random.randint(0, overflow + 1)
+ audio_data = audio_data[idx: idx + max_len]
+
+ else: # padding if too short
+ if len(audio_data) < max_len: # do nothing if equal
+ if data_filling == "repeatpad":
+ n_repeat = int(max_len / len(audio_data))
+ audio_data = audio_data.repeat(n_repeat)
+ # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+ # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
+ audio_data = F.pad(
+ audio_data,
+ (0, max_len - len(audio_data)),
+ mode="constant",
+ value=0,
+ )
+ elif data_filling == "pad":
+ audio_data = F.pad(
+ audio_data,
+ (0, max_len - len(audio_data)),
+ mode="constant",
+ value=0,
+ )
+ elif data_filling == "repeat":
+ n_repeat = int(max_len / len(audio_data))
+ audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
+ else:
+ raise NotImplementedError(
+ f"data_filling {data_filling} not implemented"
+ )
+ if data_truncating == 'fusion':
+ mel = get_mel(audio_data, audio_cfg)
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
+ sample["mel_fusion"] = mel_fusion
+ longer = torch.tensor([False])
+
+ sample["longer"] = longer
+ sample["waveform"] = audio_data
+
+ return sample
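+
+# A minimal usage sketch for get_audio_features (values are illustrative; see
+# training/infer_demo.py in this repo for a runnable example):
+#   sample = get_audio_features({}, audio_waveform, 480000,
+#                               data_truncating="fusion", data_filling="repeatpad",
+#                               audio_cfg=model_cfg["audio_cfg"])
+#   # sample now carries "waveform", "longer" and, in fusion mode, "mel_fusion"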
+
+
+def select_text(json_dict_raw, text_augment_selection):
+ # For selecting augmented text from dataset
+ if text_augment_selection is None or text_augment_selection == "none":
+ texts = json_dict_raw["text"]
+ elif text_augment_selection == "all":
+ if "text_augment_all" in json_dict_raw.keys():
+ texts = json_dict_raw["text_augment_all"]
+ else:
+ texts = json_dict_raw["text"]
+ elif text_augment_selection == "augment_only":
+ if "text_augment_all" in json_dict_raw.keys():
+ if json_dict_raw["text_augment_t5"] is None:
+ texts = json_dict_raw["text"]
+ else:
+ texts = json_dict_raw["text_augment_t5"]
+ else:
+ texts = json_dict_raw["text"]
+ else:
+ raise NotImplementedError(
+ f"text_augment_selection {text_augment_selection} not implemented"
+ )
+ return texts
+
+
+def preprocess_single(
+ sample,
+ audio_ext,
+ text_ext,
+ max_len,
+ audio_cfg,
+ tmodel,
+ class_index_dict,
+ data_filling,
+ data_truncating,
+ text_augment_selection,
+):
+ """
+ Preprocess a single sample for wdsdataloader.
+ """
+ audio_data, orig_sr = sample[audio_ext]
+ audio_data = int16_to_float32_torch(float32_to_int16_torch(audio_data[0]))
+
+ sample = get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg)
+ del sample[audio_ext]
+
+ json_dict_raw = sample[text_ext]
+
+ texts = select_text(json_dict_raw, text_augment_selection)
+ sample["full_text"] = texts
+
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
+ texts = random.choice(texts)
+ sample["raw_text"] = texts
+ sample["text"] = tokenizer(texts, tmodel=tmodel) # text shape: [num_token]
+ if class_index_dict is not None:
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+
+ # in case the re-written version is wrong, here is the old version:
+ # sample["class_label"] = np.zeros(len(class_index_dict.keys()))
+ # for x in json_dict_raw["tag"]:
+ # sample["class_label"][class_index_dict[x]] = 1
+ # sample["class_label"] = torch.tensor(sample["class_label"]).float()
+
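+        # build a multi-hot label vector: set 1 for every class whose name appears in this sample's "tag" list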
+ class_labels = np.zeros(len(class_index_dict))
+ class_labels[np.in1d(list(class_index_dict.keys()), json_dict_raw["tag"])] = 1
+ sample["class_label"] = torch.tensor(class_labels).float()
+
+ del sample[text_ext]
+ sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
+ sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
+ sample["audio_orig_sr"] = orig_sr
+ return sample
+
+
+def collate_fn_with_preprocess(batch,
+ audio_ext,
+ text_ext,
+ max_len,
+ audio_cfg,
+ args,
+ ):
+ """
+ Collate function for wdsdataloader.
+    batch: a list of dicts; each dict is a sample
+ """
+
+ class_index_dict = copy.deepcopy(args.class_index_dict) # To avoid deadlock in multiprocessing
+ data_filling = args.data_filling
+ data_truncating = args.data_truncating
+ text_augment_selection = args.text_augment_selection
+ tmodel = args.tmodel
+
+    # merge per-sample dicts into one batch dict: tokenizer outputs (dicts of tensors) are vstacked per key,
+    # tensors are stacked, numpy arrays are stacked and converted to tensors, and everything else is collected into a list
+ data_preprocessed = []
+
+ for sample in batch:
+ data_preprocessed.append(
+ preprocess_single(sample, audio_ext, text_ext, max_len, audio_cfg, tmodel, class_index_dict, data_filling,
+ data_truncating, text_augment_selection))
+
+ batch_dict = {}
+ for k in data_preprocessed[0].keys():
+        if isinstance(data_preprocessed[0][k], dict):  # deal with the BERT/RoBERTa tokenizer output (a dict of tensors)
+ batch_dict[k] = {}
+ for kk in data_preprocessed[0][k].keys():
+ tmp = []
+ for i in range(len(data_preprocessed)):
+ tmp.append(data_preprocessed[i][k][kk])
+ batch_dict[k][kk] = torch.vstack(tmp)
+ elif isinstance(data_preprocessed[0][k], torch.Tensor):
+ batch_dict[k] = torch.stack([sample[k] for sample in data_preprocessed])
+ elif isinstance(data_preprocessed[0][k], np.ndarray):
+ batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in data_preprocessed]))
+ else:
+ batch_dict[k] = [sample[k] for sample in data_preprocessed]
+ del data_preprocessed
+ return batch_dict
+
+
+def get_wds_dataset(
+ args,
+ model_cfg,
+ is_train,
+ audio_ext="flac",
+ text_ext="json",
+ max_len=480000,
+ proportion=1.0,
+ sizefilepath_=None,
+ is_local=None,
+):
+ """
+ Get a dataset for wdsdataloader.
+ """
+    if is_local is None and args.remotedata is not None:
+ is_local = not args.remotedata
+
+ input_shards = args.train_data if is_train else args.val_data
+ assert input_shards is not None
+
+    if sizefilepath_ is not None:
+ sizefilepath = sizefilepath_
+ else:
+ sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
+
+ if proportion != 1.0:
+ num_samples, num_shards, input_shards, _ = sample_prop(
+ sizefilepath, input_shards, proportion, is_local=is_local
+ )
+ else:
+ num_samples, num_shards = get_dataset_size(
+ input_shards, sizefilepath_=sizefilepath_, is_local=is_local
+ )
+
+ if not num_samples:
+ if is_train:
+ num_samples = args.train_num_samples
+ if not num_samples:
+ raise RuntimeError(
+ "Currently, number of dataset samples must be specified for training dataset. "
+ "Please specify via `--train-num-samples` if no dataset length info present."
+ )
+ else:
+ num_samples = (
+ args.val_num_samples or 0
+ ) # eval will just exhaust the iterator if not specified
+
+ pipeline = [wds.SimpleShardList(input_shards)]
+ # at this point we have an iterator over all the shards
+    # TODO (yusong): add an if statement for distributed training; if not distributed, we don't need split_by_node
+ if is_train or args.parallel_eval:
+ pipeline.extend(
+ [
+ wds.detshuffle(
+ bufsize=_SHARD_SHUFFLE_SIZE,
+ initial=_SHARD_SHUFFLE_INITIAL,
+ seed=args.seed,
+ ),
+ wds.split_by_node,
+ wds.split_by_worker,
+ # at this point, we have an iterator over the shards assigned to each worker at each node
+ wds.tarfile_to_samples(handler=log_and_continue),
+ wds.shuffle(
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
+ initial=_SAMPLE_SHUFFLE_INITIAL,
+ rng=random.Random(args.seed),
+ ),
+ # wds.repeatedly, # FIXME determine if this is beneficial
+ ]
+ )
+ else:
+ pipeline.extend(
+ [
+ wds.split_by_worker,
+ # at this point, we have an iterator over the shards assigned to each worker
+ wds.tarfile_to_samples(handler=log_and_continue),
+ ]
+ )
+
+ pipeline.append(
+ wds.decode(wds.torch_audio),
+ )
+
+ pipeline.append(
+ wds.batched(
+ args.batch_size,
+ partial=not (is_train or args.parallel_eval),
+ collation_fn=partial(collate_fn_with_preprocess,
+ audio_ext=audio_ext,
+ text_ext=text_ext,
+ max_len=max_len,
+ audio_cfg=model_cfg['audio_cfg'],
+ args=args,
+ ),
+
+ )
+ )
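+    # batching is done inside the webdataset pipeline itself (the WebLoader below is created with
+    # batch_size=None); partial last batches are only kept for single-process evaluation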
+
+ dataset = wds.DataPipeline(*pipeline)
+ if is_train or args.parallel_eval:
+        # (yusong): Currently parallel evaluation is not precise because we repeat the last few samples.
+ # (yusong): See comments below.
+ # roll over and repeat a few samples to get same number of full batches on each node
+ global_batch_size = args.batch_size * args.world_size
+ num_batches = math.ceil(num_samples / global_batch_size)
+ num_workers = max(1, args.workers)
+ num_worker_batches = math.ceil(
+ num_batches / num_workers
+ ) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+ dataset = dataset.with_epoch(
+ num_worker_batches
+ ) # each worker is iterating over this
+ else:
+ # last batches are partial, eval is done on single (master) node
+ num_batches = math.ceil(num_samples / args.batch_size)
+
+ kwargs = {}
+ if args.horovod: # multi-node training on summit
+ kwargs["multiprocessing_context"] = "forkserver"
+
+ if is_train:
+ if args.prefetch_factor:
+ prefetch_factor = args.prefetch_factor
+ else:
+ prefetch_factor = max(2, args.batch_size // args.workers)
+ else:
+ prefetch_factor = 2
+
+ dataloader = wds.WebLoader(
+ dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=args.workers,
+ pin_memory=True,
+ prefetch_factor=prefetch_factor,
+ **kwargs
+ )
+
+ # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+ # if is_train:
+ # # roll over and repeat a few samples to get same number of full batches on each node
+ # global_batch_size = args.batch_size * args.world_size
+ # num_batches = math.ceil(num_samples / global_batch_size)
+ # num_workers = max(1, args.workers)
+ # num_batches = math.ceil(num_batches / num_workers) * num_workers
+ # num_samples = num_batches * global_batch_size
+ # dataloader = dataloader.with_epoch(num_batches)
+ # else:
+ # # last batches are partial, eval is done on single (master) node
+ # num_batches = math.ceil(num_samples / args.batch_size)
+
+ # add meta-data to dataloader instance for convenience
+ dataloader.num_batches = num_batches
+ dataloader.num_samples = num_samples
+
+ return DataInfo(dataloader, None)
+
+
+def wds_batch_list2dict(
+ batch,
+ keys=[
+ "__url__",
+ "__key__",
+ "waveform",
+ "text",
+ "raw_text",
+ "audio_name",
+ "text_name",
+ "audio_orig_sr",
+ ],
+):
+ """
+ Return a dictionary of the batch, with keys as the names of the fields.
+ """
+ assert len(keys) == len(
+ batch
+ ), "batch must have same number of keys as keys argument"
+ return {keys[i]: batch[i] for i in range(len(batch))}
+
+
+
+def get_toy_dataset(args, model_cfg, is_train):
+ index_path = args.train_data if is_train else args.val_data
+ ipc_path = args.train_ipc if is_train else args.val_ipc
+ assert index_path and ipc_path
+ eval_mode = not is_train
+ dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
+
+ num_samples = len(dataset)
+ sampler = (
+ DistributedSampler(dataset, shuffle=False)
+ if args.distributed and is_train
+ else None
+ )
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.workers,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+def get_dataset_fn(dataset_type):
+ if dataset_type == "webdataset":
+ return get_wds_dataset
+ elif dataset_type == "toy":
+ return get_toy_dataset
+ else:
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, model_cfg):
+ data = {}
+
+ args.class_index_dict = load_class_label(args.class_label_path)
+
+ if args.datasetinfos is None:
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+ if args.dataset_type == "webdataset":
+ args.train_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ args.datasetinfos,
+ islocal=not args.remotedata,
+ proportion=args.dataset_proportion,
+ dataset_path=args.datasetpath,
+ full_dataset=args.full_train_dataset,
+ )
+
+ if args.full_train_dataset is None:
+ args.full_train_dataset = []
+ if args.exclude_eval_dataset is None:
+ args.exclude_eval_dataset = []
+ excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
+
+ val_dataset_names = [n for n in args.datasetnames if n not in excluded_eval_datasets] \
+ if excluded_eval_datasets else args.datasetnames
+ args.val_dataset_names = val_dataset_names
+ args.val_data = get_tar_path_from_dataset_name(
+ val_dataset_names,
+ ["valid", "test", "eval"],
+ islocal=not args.remotedata,
+ proportion=1,
+ dataset_path=args.datasetpath,
+ full_dataset=None,
+ )
+
+ if args.train_data:
+ data["train"] = get_dataset_fn(args.dataset_type)(
+ args, model_cfg, is_train=True
+ )
+
+ if args.val_data:
+ data["val"] = get_dataset_fn(args.dataset_type)(
+ args, model_cfg, is_train=False
+ )
+
+ return data
diff --git a/src/laion_clap/training/distributed.py b/src/laion_clap/training/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..adb0e927a64dbe7fc83fecf65be054ac6bd28a94
--- /dev/null
+++ b/src/laion_clap/training/distributed.py
@@ -0,0 +1,139 @@
+import os
+
+import torch
+import socket
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def is_global_master(args):
+ return args.rank == 0
+
+
+def is_local_master(args):
+ return args.local_rank == 0
+
+
+def is_master(args, local=False):
+ return is_local_master(args) if local else is_global_master(args)
+
+
+def is_using_horovod():
+    # NOTE: with a horovod run the OMPI vars should be set, but with SLURM the PMI vars will be set.
+    # Differentiating between horovod and DDP use via SLURM may not be possible, so the horovod arg is still required...
+ ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
+ pmi_vars = ["PMI_RANK", "PMI_SIZE"]
+    return all(var in os.environ for var in ompi_vars) or all(var in os.environ for var in pmi_vars)
+
+
+def is_using_distributed():
+ if 'WORLD_SIZE' in os.environ:
+ return int(os.environ['WORLD_SIZE']) > 1
+ if 'SLURM_NTASKS' in os.environ:
+ return int(os.environ['SLURM_NTASKS']) > 1
+ return False
+
+
+def world_info_from_env():
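+    # probe SLURM, MPI/OpenMPI and torch.distributed environment variables in that order;
+    # the first variable found wins, so SLURM values take precedence over torchrun's
+    # LOCAL_RANK / RANK / WORLD_SIZE when both are present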
+ local_rank = 0
+ for v in ('SLURM_LOCALID', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'LOCAL_RANK'):
+ if v in os.environ:
+ local_rank = int(os.environ[v])
+ break
+ global_rank = 0
+ for v in ('SLURM_PROCID', 'PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'RANK'):
+ if v in os.environ:
+ global_rank = int(os.environ[v])
+ break
+ world_size = 1
+ for v in ('SLURM_NTASKS', 'PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'WORLD_SIZE'):
+ if v in os.environ:
+ world_size = int(os.environ[v])
+ break
+
+ return local_rank, global_rank, world_size
+
+
+def init_distributed_device(args):
+ # Distributed training = training on more than one GPU.
+ # Works in both single and multi-node scenarios.
+ args.distributed = False
+ args.world_size = 1
+ args.rank = 0 # global rank
+ args.local_rank = 0
+ if args.horovod:
+ assert hvd is not None, "Horovod is not installed"
+ hvd.init()
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.local_rank = local_rank
+ args.rank = world_rank
+ args.world_size = world_size
+ # args.local_rank = int(hvd.local_rank())
+ # args.rank = hvd.rank()
+ # args.world_size = hvd.size()
+ args.distributed = True
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ print(f"Distributed training: local_rank={args.local_rank}, "
+ f"rank={args.rank}, world_size={args.world_size}, "
+ f"hostname={socket.gethostname()}, pid={os.getpid()}")
+ elif is_using_distributed():
+ if 'SLURM_PROCID' in os.environ:
+ # DDP via SLURM
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+ # SLURM var -> torch.distributed vars in case needed
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ elif 'OMPI_COMM_WORLD_SIZE' in os.environ: # using Summit cluster
+ world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.local_rank = local_rank
+ args.rank = world_rank
+ args.world_size = world_size
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ else:
+ # DDP via torchrun, torch.distributed.launch
+ args.local_rank, _, _ = world_info_from_env()
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url)
+ args.world_size = torch.distributed.get_world_size()
+ args.rank = torch.distributed.get_rank()
+ args.distributed = True
+ print(f"Distributed training: local_rank={args.local_rank}, "
+ f"rank={args.rank}, world_size={args.world_size}, "
+ f"hostname={socket.gethostname()}, pid={os.getpid()}")
+
+ if torch.cuda.is_available():
+ if args.distributed and not args.no_set_device_rank:
+ device = 'cuda:%d' % args.local_rank
+ else:
+ device = 'cuda:0'
+ torch.cuda.set_device(device)
+ else:
+ device = 'cpu'
+ args.device = device
+ device = torch.device(device)
+ return device
diff --git a/src/laion_clap/training/imagenet_zeroshot_data.py b/src/laion_clap/training/imagenet_zeroshot_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78987448805afc228b2941302a2894818cac497
--- /dev/null
+++ b/src/laion_clap/training/imagenet_zeroshot_data.py
@@ -0,0 +1,254 @@
+# NOTE: This script is currently not supported for CLAP.
+
+imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
+ "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
+ "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
+ "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
+ "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
+ "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
+ "box turtle", "banded gecko", "green iguana", "Carolina anole",
+ "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
+ "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
+ "American alligator", "triceratops", "worm snake", "ring-necked snake",
+ "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
+ "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
+ "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
+ "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
+ "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
+ "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
+ "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
+ "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
+ "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
+ "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
+ "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
+ "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
+ "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
+ "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
+ "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
+ "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
+ "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
+ "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
+ "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
+ "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
+ "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
+ "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
+ "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
+ "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
+ "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
+ "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
+ "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
+ "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
+ "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
+ "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
+ "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
+ "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
+ "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
+ "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
+ "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
+ "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
+ "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
+ "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
+ "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
+ "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
+ "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
+ "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
+ "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
+ "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
+ "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
+ "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
+ "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
+ "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
+ "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
+ "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
+ "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
+ "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
+ "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
+ "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
+ "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
+ "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
+ "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
+ "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
+ "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
+ "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
+ "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
+ "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
+ "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
+ "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
+ "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
+ "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
+ "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
+ "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
+ "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
+ "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
+ "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
+ "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
+ "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
+ "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
+ "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
+ "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
+ "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
+ "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
+ "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
+ "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
+ "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
+ "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
+ "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
+ "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
+ "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
+ "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
+ "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
+ "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
+ "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
+ "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
+ "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
+ "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
+ "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
+ "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
+ "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
+ "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
+ "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
+ "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
+ "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
+ "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
+ "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
+ "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
+ "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
+ "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
+ "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
+ "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
+ "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
+ "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
+ "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
+ "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
+ "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
+ "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
+ "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
+ "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
+ "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
+ "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
+ "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
+ "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
+ "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
+ "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
+ "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
+ "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
+ "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
+ "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
+ "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
+ "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
+ "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
+ "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
+ "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
+ "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
+ "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
+ "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
+ "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
+ "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
+ "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
+ "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
+ "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
+ "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
+ "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
+ "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
+ "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
+ "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
+ "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
+ "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
+ "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
+ "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
+ "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
+ "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
+ "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
+ "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
+ "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
+ "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
+ "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
+ "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
+ "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]
+
+
+
+
+
+openai_imagenet_template = [
+ lambda c: f'a bad photo of a {c}.',
+ lambda c: f'a photo of many {c}.',
+ lambda c: f'a sculpture of a {c}.',
+ lambda c: f'a photo of the hard to see {c}.',
+ lambda c: f'a low resolution photo of the {c}.',
+ lambda c: f'a rendering of a {c}.',
+ lambda c: f'graffiti of a {c}.',
+ lambda c: f'a bad photo of the {c}.',
+ lambda c: f'a cropped photo of the {c}.',
+ lambda c: f'a tattoo of a {c}.',
+ lambda c: f'the embroidered {c}.',
+ lambda c: f'a photo of a hard to see {c}.',
+ lambda c: f'a bright photo of a {c}.',
+ lambda c: f'a photo of a clean {c}.',
+ lambda c: f'a photo of a dirty {c}.',
+ lambda c: f'a dark photo of the {c}.',
+ lambda c: f'a drawing of a {c}.',
+ lambda c: f'a photo of my {c}.',
+ lambda c: f'the plastic {c}.',
+ lambda c: f'a photo of the cool {c}.',
+ lambda c: f'a close-up photo of a {c}.',
+ lambda c: f'a black and white photo of the {c}.',
+ lambda c: f'a painting of the {c}.',
+ lambda c: f'a painting of a {c}.',
+ lambda c: f'a pixelated photo of the {c}.',
+ lambda c: f'a sculpture of the {c}.',
+ lambda c: f'a bright photo of the {c}.',
+ lambda c: f'a cropped photo of a {c}.',
+ lambda c: f'a plastic {c}.',
+ lambda c: f'a photo of the dirty {c}.',
+ lambda c: f'a jpeg corrupted photo of a {c}.',
+ lambda c: f'a blurry photo of the {c}.',
+ lambda c: f'a photo of the {c}.',
+ lambda c: f'a good photo of the {c}.',
+ lambda c: f'a rendering of the {c}.',
+ lambda c: f'a {c} in a video game.',
+ lambda c: f'a photo of one {c}.',
+ lambda c: f'a doodle of a {c}.',
+ lambda c: f'a close-up photo of the {c}.',
+ lambda c: f'a photo of a {c}.',
+ lambda c: f'the origami {c}.',
+ lambda c: f'the {c} in a video game.',
+ lambda c: f'a sketch of a {c}.',
+ lambda c: f'a doodle of the {c}.',
+ lambda c: f'a origami {c}.',
+ lambda c: f'a low resolution photo of a {c}.',
+ lambda c: f'the toy {c}.',
+ lambda c: f'a rendition of the {c}.',
+ lambda c: f'a photo of the clean {c}.',
+ lambda c: f'a photo of a large {c}.',
+ lambda c: f'a rendition of a {c}.',
+ lambda c: f'a photo of a nice {c}.',
+ lambda c: f'a photo of a weird {c}.',
+ lambda c: f'a blurry photo of a {c}.',
+ lambda c: f'a cartoon {c}.',
+ lambda c: f'art of a {c}.',
+ lambda c: f'a sketch of the {c}.',
+ lambda c: f'a embroidered {c}.',
+ lambda c: f'a pixelated photo of a {c}.',
+ lambda c: f'itap of the {c}.',
+ lambda c: f'a jpeg corrupted photo of the {c}.',
+ lambda c: f'a good photo of a {c}.',
+ lambda c: f'a plushie {c}.',
+ lambda c: f'a photo of the nice {c}.',
+ lambda c: f'a photo of the small {c}.',
+ lambda c: f'a photo of the weird {c}.',
+ lambda c: f'the cartoon {c}.',
+ lambda c: f'art of the {c}.',
+ lambda c: f'a drawing of the {c}.',
+ lambda c: f'a photo of the large {c}.',
+ lambda c: f'a black and white photo of a {c}.',
+ lambda c: f'the plushie {c}.',
+ lambda c: f'a dark photo of a {c}.',
+ lambda c: f'itap of a {c}.',
+ lambda c: f'graffiti of the {c}.',
+ lambda c: f'a toy {c}.',
+ lambda c: f'itap of my {c}.',
+ lambda c: f'a photo of a cool {c}.',
+ lambda c: f'a photo of a small {c}.',
+ lambda c: f'a tattoo of the {c}.',
+]
diff --git a/src/laion_clap/training/infer_demo.py b/src/laion_clap/training/infer_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b39f3f43fc9f295b244bc1ca2b444f4969488cba
--- /dev/null
+++ b/src/laion_clap/training/infer_demo.py
@@ -0,0 +1,92 @@
+import torch
+import librosa
+from clap_module import create_model
+from training.data import get_audio_features
+from training.data import int16_to_float32, float32_to_int16
+from transformers import RobertaTokenizer
+
+tokenize = RobertaTokenizer.from_pretrained('roberta-base')
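+# pad / truncate to 77 tokens (a CLIP-style context length) and squeeze out the batch
+# dimension the tokenizer adds, so each field is a 1-D tensor per sample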
+def tokenizer(text):
+ result = tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
+def infer_text():
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ precision = 'fp32'
+ amodel = 'HTSAT-tiny' # or 'PANN-14'
+ tmodel = 'roberta' # the best text encoder in our training
+ enable_fusion = True # False if you do not want to use the fusion model
+ fusion_type = 'aff_2d'
+ pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded.
+
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision=precision,
+ device=device,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type
+ )
+ # load the text, can be a list (i.e. batch size)
+ text_data = ["I love the contrastive learning", "I love the pretrain model"]
+    # tokenize for roberta; if you want to tokenize for another text encoder, please refer to data.py#L43-90
+ text_data = tokenizer(text_data)
+ model.eval()
+ text_embed = model.get_text_embedding(text_data)
+ text_embed = text_embed.detach().cpu().numpy()
+ print(text_embed)
+ print(text_embed.shape)
+
+def infer_audio():
+
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ precision = 'fp32'
+ amodel = 'HTSAT-tiny' # or 'PANN-14'
+ tmodel = 'roberta' # the best text encoder in our training
+ enable_fusion = True # False if you do not want to use the fusion model
+ fusion_type = 'aff_2d'
+ pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded.
+
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision=precision,
+ device=device,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type
+ )
+
+    # load the waveform of shape (T,); it should be resampled to 48000 Hz
+ audio_waveform, sr = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000)
+ # quantize
+ audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
+ audio_waveform = torch.from_numpy(audio_waveform).float()
+ audio_dict = {}
+
+    # the 'fusion' truncation mode can be changed to 'rand_trunc' when running a non-fusion model
+ audio_dict = get_audio_features(
+ audio_dict, audio_waveform, 480000,
+ data_truncating='fusion',
+ data_filling='repeatpad',
+ audio_cfg=model_cfg['audio_cfg']
+ )
+ model.eval()
+    # a list can be sent to the model to process many audio tracks at once (i.e. as a batch)
+ audio_embed = model.get_audio_embedding([audio_dict])
+ audio_embed = audio_embed.detach().cpu().numpy()
+ print(audio_embed)
+ print(audio_embed.shape)
+
+
+
+if __name__ == "__main__":
+ infer_text()
+ # infer_audio()
diff --git a/src/laion_clap/training/logger.py b/src/laion_clap/training/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9abed92568d459cbc8d6094ae3901935d89621
--- /dev/null
+++ b/src/laion_clap/training/logger.py
@@ -0,0 +1,26 @@
+import logging
+
+
+def setup_logging(log_file, level, include_host=False):
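+    # set the root log level (and the level of every logger created so far), then attach a
+    # stream handler and, when log_file is given, a file handler to the root logger;
+    # lp_main calls this as setup_logging(args.log_path, args.log_level)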
+ if include_host:
+ import socket
+ hostname = socket.gethostname()
+ formatter = logging.Formatter(
+ f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
+ else:
+ formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
+
+ logging.root.setLevel(level)
+ loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
+ for logger in loggers:
+ logger.setLevel(level)
+
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(formatter)
+ logging.root.addHandler(stream_handler)
+
+ if log_file:
+ file_handler = logging.FileHandler(filename=log_file)
+ file_handler.setFormatter(formatter)
+ logging.root.addHandler(file_handler)
+
diff --git a/src/laion_clap/training/lp_main.py b/src/laion_clap/training/lp_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e88f7356950ef03c78ca4d88681eb78ff1b4f6a
--- /dev/null
+++ b/src/laion_clap/training/lp_main.py
@@ -0,0 +1,643 @@
+import logging
+import os
+import random
+from datetime import datetime
+import copy
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.cuda.amp import GradScaler
+import time
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+try:
+ import torch.utils.tensorboard as tensorboard
+except ImportError:
+ tensorboard = None
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from clap_module import create_model_and_transforms, trace_model, create_model
+from training.data import get_data
+from training.params import parse_args
+from training.distributed import is_master, init_distributed_device, world_info_from_env
+from training.logger import setup_logging
+from training.scheduler import cosine_lr
+from training.lp_train import train_one_epoch, evaluate
+from clap_module.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer
+from clap_module.utils import load_p, load_class_label
+from clap_module.linear_probe import LinearProbe
+
+
+def maintain_ckpts(args, startidx, all_idx_len):
+ for i in reversed(range(startidx, all_idx_len)):
+ if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
+ os.rename(
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
+ )
+ if os.path.exists(
+ os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
+ ):
+ os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
+ return
+
+
+def update_top_k_performance(
+ new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
+):
+ """
+ Record the top-k performance of the current epoch.
+    current_top_k_ckpt_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...}
+ """
+ if isinstance(new_metrics_inputs, (list, tuple)):
+ new_metrics_inputs = np.mean(new_metrics_inputs)
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, dict):
+ new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, (float, int)):
+ update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
+ sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
+ sorted_values = sorted(
+ current_top_k_ckpt_metrics.values(), reverse=bignumbetter
+ )
+ sorted_values_ = copy.deepcopy(sorted_values)
+ sorted_values.append(new_metrics_inputs)
+ sorted_values = sorted(sorted_values, reverse=bignumbetter)
+ sorted_values = sorted_values[:-1]
+
+ if sorted_values == sorted_values_:
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+ else:
+ for i in range(len(sorted_keys)):
+ if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
+ current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
+ update_flag[sorted_keys[i]] = True
+ for i in range(len(update_flag)):
+ if update_flag[i]:
+ maintain_ckpts(args, i, len(sorted_keys))
+ torch.save(
+ ckpt,
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ )
+ break
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+
+
+# def updateifNone(a, b):
+# a = b if None else a
+# return a
+
+
+def is_pretrained_params(n):
+ return (
+ n.startswith("clap_model.transformer")
+ or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
+ or n.startswith("clap_model.token_embedding")
+ or n.startswith("clap_model.ln_final")
+ or n.startswith("clap_model.logit_scale_t")
+ )
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
+
+def config_lp_optimizer(model, data, args):
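+    # builds the optimizer(s)/scheduler(s) for linear-probe training and returns
+    # (optimizer, scheduler, text_freeze_parameters); optimizer/scheduler are dicts keyed by
+    # parameter group: "clap" (or "text"/"audio" when args.split_opt is set) when the CLAP
+    # backbone is trained, or "lp" when args.lp_freeze keeps the backbone frozen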
+ # set wd-related params to 0 if use adam optimizer
+ if args.optimizer == "adam":
+ args.wd = 0
+ args.wd_pretrained = 0
+ args.wd_new = 0
+
+ in_clap = (
+ lambda n, p: n.startswith("clap_model")
+ )
+
+ named_parameters = list(model.named_parameters())
+
+ optimizer = {}
+ scheduler = {}
+
+ # freeze text encoder
+ text_freeze_parameters = [
+ p
+ for n, p in named_parameters
+ if n.startswith("clap_model.transformer")
+ or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
+ or n.startswith("clap_model.token_embedding")
+ or n.startswith("clap_model.ln_final")
+ ]
+
+ if args.freeze_text:
+ logging.info("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+
+ if not args.lp_freeze:
+ exclude = (
+ lambda n, p: p.ndim < 2
+ or "bn" in n
+ or "ln" in n
+ or "bias" in n
+ or "logit_scale" in n
+ )
+ include = lambda n, p: not exclude(n, p)
+
+ # (yusong): we do not split the learning rate anymore
+ # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad
+ gain_or_bias_params = [
+ p for n, p in named_parameters if exclude(n, p) and p.requires_grad
+ ]
+ # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad]
+ rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
+
+ if args.train_data is None:
+ optimizer = None
+ scheduler = None
+ else:
+ total_steps = data["train"].dataloader.num_batches * args.epochs
+
+ if args.split_opt:
+ for x in ["lr", "beta1", "beta2", "eps", "wd"]:
+ for y in ["_new", "_pretrained"]:
+ if getattr(args, x + y) is None:
+ setattr(args, x + y, getattr(args, x))
+
+ gain_or_bias_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ rest_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ gain_or_bias_new_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n))
+ ]
+ rest_new_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n))
+ ]
+
+ pretrained_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
+ {
+ "params": rest_pretrained_params,
+ "weight_decay": args.wd_pretrained,
+ },
+ ],
+ lr=args.lr_pretrained,
+ betas=(args.beta1_pretrained, args.beta2_pretrained),
+ eps=args.eps_pretrained,
+ momentum=args.momentum_pretrained,
+ optimizer_name=args.optimizer,
+ )
+ pretrained_params_scheduler = cosine_lr(
+ pretrained_params_optimizer,
+ args.lr_pretrained,
+ args.warmup,
+ total_steps,
+ )
+
+ new_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_new_params, "weight_decay": 0.0},
+ {"params": rest_new_params, "weight_decay": args.wd_new},
+ ],
+ lr=args.lr_new,
+ betas=(args.beta1_new, args.beta2_new),
+ eps=args.eps_new,
+ momentum=args.momentum_new,
+ optimizer_name=args.optimizer,
+ )
+ new_params_scheduler = cosine_lr(
+ new_params_optimizer, args.lr_new, args.warmup, total_steps
+ )
+
+ optimizer["text"] = pretrained_params_optimizer
+ optimizer["audio"] = new_params_optimizer
+ scheduler["text"] = pretrained_params_scheduler
+ scheduler["audio"] = new_params_scheduler
+
+ if args.horovod:
+ pretrained_params_optimizer = hvd.DistributedOptimizer(
+ pretrained_params_optimizer,
+ named_parameters=model.named_parameters(),
+ )
+ new_params_optimizer = hvd.DistributedOptimizer(
+ new_params_optimizer, named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0)
+ hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
+ else:
+
+ optimizer["clap"] = get_optimizer(
+ [
+ {"params": gain_or_bias_params, "weight_decay": 0.0},
+ {"params": rest_params, "weight_decay": args.wd},
+ ],
+ lr=args.lr,
+ betas=(args.beta1, args.beta2),
+ eps=args.eps,
+ momentum=args.momentum,
+ optimizer_name=args.optimizer,
+ )
+ scheduler["clap"] = cosine_lr(optimizer["clap"], args.lr, args.warmup, total_steps)
+
+ if args.horovod:
+ optimizer["clap"] = hvd.DistributedOptimizer(
+ optimizer["clap"], named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0)
+
+ # linear probe optimizer
+ else:
+ lp_params = [p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad]
+ lp_optim = get_optimizer(lp_params, lr=args.lp_lr, betas=(args.beta1, args.beta2), eps=args.eps, momentum=0.9,
+ optimizer_name=args.optimizer)
+ optimizer["lp"] = lp_optim
+
+ return optimizer, scheduler, text_freeze_parameters
+
+
+def main():
+ args = parse_args()
+
+ time.sleep(args.sleep)
+
+ # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+ args.amodel = args.amodel.replace("/", "-")
+ # download sizes.json file
+
+ # (yusong): the below two lines are for debug
+ # print("setting up faulthandler")
+ # faulthandler.register(10)
+
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ args.class_index_dict = load_class_label(args.class_label_path)
+
+ # get the name of the experiments
+ if args.name is None:
+ args.name = "-".join(
+ [
+ datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
+ f"linear_probe"
+ f"model_{args.amodel}",
+ f"lr_{args.lr}",
+ f"b_{args.batch_size}",
+ f"j_{args.workers}",
+ f"p_{args.precision}",
+ ]
+ )
+
+ # discover initial world args early so we can log properly
+ args.distributed = False
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+ if args.remotedata and is_master(args):
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ args.log_path = None
+ if is_master(args, local=args.log_local):
+ log_base_path = os.path.join(args.logs, args.name)
+ os.makedirs(log_base_path, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path, log_filename)
+
+ # avoid log dir in same name:
+ postfix = 0
+ while os.path.exists(args.log_path):
+ postfix += 1
+ log_base_path_new = log_base_path+'-'+str(postfix)
+ os.makedirs(log_base_path_new, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path_new, log_filename)
+ # print(
+ # "Error. Experiment already exists. Use --name {} to specify a new experiment."
+ # )
+ # return -1
+
+ # Set logger
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ setup_logging(args.log_path, args.log_level)
+
+ # fully initialize distributed device environment
+ device = init_distributed_device(args)
+
+ args.wandb = "wandb" in args.report_to or "all" in args.report_to
+ args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
+ if is_master(args):
+ args.tensorboard_path = (
+ os.path.join(args.logs, args.name, "tensorboard")
+ if args.tensorboard
+ else ""
+ )
+ args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
+ for dirname in [args.tensorboard_path, args.checkpoint_path]:
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ else:
+ args.tensorboard_path = ""
+ args.checkpoint_path = ""
+
+ if args.copy_codebase:
+ copy_codebase(args)
+
+ assert args.precision in ["amp", "fp16", "fp32"]
+ if args.precision == "fp16":
+ logging.warning(
+ "It is recommended to use AMP mixed-precision instead of FP16. "
+ "FP16 support needs further verification and tuning, especially for train."
+ )
+
+ if args.horovod:
+ logging.info(
+ f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ elif args.distributed:
+ logging.info(
+ f"Running in distributed mode with multiple processes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ else:
+ logging.info(f"Running with a single process. Device {args.device}.")
+
+ logging.info(f'openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}')
+
+ # Create CLAP model
+ clap_model, clap_model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ pretrained_audio=args.pretrained_audio,
+ pretrained_text=args.pretrained_text,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ args.lp_out_ch = len(list(args.class_index_dict.keys()))
+ # Linear Probe
+ logging.info(f"linear probe using mlp: {args.lp_mlp}")
+ logging.info(f"linear probe using freeze: {args.lp_freeze}")
+ logging.info(f"linear probe act layer: {args.lp_act}")
+ logging.info(f"linear probe out ch: {args.lp_out_ch}")
+ logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}")
+ logging.info(f"linear probe loss func: {args.lp_loss}")
+ logging.info(f"linear probe lp_metrics: {args.lp_metrics}")
+
+ model = LinearProbe(
+ clap_model,
+ mlp=args.lp_mlp, freeze=args.lp_freeze,
+ in_ch=512, out_ch=args.lp_out_ch,
+ act=args.lp_act
+ ) # in_ch is fixed (i.e., 512)
+ model = model.to(device)
+
+ if args.horovod:
+ with torch.no_grad():
+ for param in model.parameters():
+ param.set_(param.contiguous())
+
+ if args.trace:
+ model = trace_model(model, batch_size=args.batch_size, device=device)
+
+ if is_master(args):
+ logging.info("Linear Probe CLAP Model:")
+ logging.info(f"{str(clap_model)}")
+ logging.info("Params:")
+ params_file = os.path.join(args.logs, args.name, "params.txt")
+ with open(params_file, "w") as f:
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ logging.info(f" {name}: {val}")
+ f.write(f"{name}: {val}\n")
+
+ if args.distributed and not args.horovod:
+ if args.use_bn_sync:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ ddp_args = {}
+ if args.ddp_static_graph:
+ # this doesn't exist in older PyTorch, arg only added if enabled
+ ddp_args["static_graph"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[device], find_unused_parameters=True, **ddp_args
+ )
+
+ data = get_data(args, clap_model_cfg)
+ assert len(data), "At least one train or eval dataset must be specified."
+ if args.trace:
+ assert "train" not in data, "Cannot train with traced model"
+
+
+ optimizer, scheduler, text_freeze_parameters = config_lp_optimizer(model, data, args)
+
+
+ scaler = GradScaler() if args.precision == "amp" else None
+
+ # optionally resume from a checkpoint
+ start_epoch = 0
+ if args.resume is not None:
+ if os.path.isfile(args.resume):
+ checkpoint = torch.load(args.resume, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if not args.distributed and next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module.") :]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ if args.split_opt:
+ if optimizer is not None:
+ for k, o_ in optimizer.items():
+ o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if scaler is not None and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ logging.info(
+ f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ logging.info(
+ f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ else:
+ logging.info("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
+ args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+ writer = None
+ if args.save_logs and args.tensorboard:
+ assert tensorboard is not None, "Please install tensorboard."
+ writer = tensorboard.SummaryWriter(args.tensorboard_path)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, "Please install wandb."
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ wandb.init(
+ project="clap",
+ notes=args.wandb_notes,
+ name=args.wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ if args.debug:
+ wandb.watch(model, log="all")
+ wandb.save(params_file)
+ logging.debug("Finished loading wandb.")
+
+ if "train" not in data:
+ evaluate(model, data, start_epoch, args, writer)
+ return
+ elif start_epoch == 0 and "val" in data and not args.no_eval:
+ evaluate(model, data, 0, args, writer)
+ if args.save_top_performance:
+ current_top_k_ckpt_metrics = {
+ i: 0 for i in range(args.save_top_performance)
+ } # initialize the top-k metric for ckpts to 0
+
+ for epoch in range(start_epoch, args.epochs):
+ # freeze the text param after (include) args.freeze_text_after, this is -1 by default
+ if epoch == args.freeze_text_after:
+            print("Text pretrained parameters are frozen from this epoch onward.")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ if is_master(args):
+ logging.info(f"Start epoch {epoch}")
+
+ train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
+ completed_epoch = epoch + 1
+
+ if any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) and not args.no_eval:
+ metrics = evaluate(model, data, completed_epoch, args, writer)
+ if args.save_top_performance:
+ top_k_dataset = args.top_k_checkpoint_select_dataset
+ top_k_metric = args.top_k_checkpoint_select_metric
+ filtered_metrics = [
+ v
+ for k, v in metrics.items()
+ if top_k_metric in k and top_k_dataset in k
+                ]  # collect all metrics matching the selected metric/dataset (e.g. R@10) and use them to update the top-k checkpoints
+ # Saving checkpoints.
+ if args.save_logs:
+            if isinstance(optimizer, dict):
+                opt_dict = {
+                    k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
+                }
+            else:
+                opt_dict = {"optimizer": optimizer.state_dict()}
+ checkpoint_dict = {
+ "epoch": completed_epoch,
+ "name": args.name,
+ "state_dict": model.state_dict(),
+ }
+ checkpoint_dict.update(opt_dict)
+ if scaler is not None:
+ checkpoint_dict["scaler"] = scaler.state_dict()
+
+ if completed_epoch == args.epochs or (
+ args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+ ):
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
+ )
+ if args.save_most_recent:
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
+ )
+ if args.save_top_performance and not args.no_eval:
+ update_top_k_performance(
+ filtered_metrics,
+ current_top_k_ckpt_metrics,
+ args,
+ checkpoint_dict,
+ bignumbetter=True,
+ )
+
+ if args.wandb and is_master(args):
+ wandb.finish()
+
+
+def copy_codebase(args):
+ from shutil import copytree, ignore_patterns
+
+ new_code_path = os.path.join(args.logs, args.name, "code")
+ if os.path.exists(new_code_path):
+ print(
+ f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
+ )
+ return -1
+ print(f"Copying codebase to {new_code_path}")
+ current_code_path = os.path.realpath(__file__)
+ for _ in range(3):
+ current_code_path = os.path.dirname(current_code_path)
+ copytree(
+ current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
+ )
+ print("Done copying code.")
+ return 1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/laion_clap/training/lp_train.py b/src/laion_clap/training/lp_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d686336c824dd34b45f056c414d761540141a46f
--- /dev/null
+++ b/src/laion_clap/training/lp_train.py
@@ -0,0 +1,292 @@
+import json
+import logging
+import math
+import os
+import time
+from contextlib import suppress
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from clap_module import LPLoss, LPMetrics, lp_gather_features
+from clap_module.utils import do_mixup, get_mix_lambda
+from .distributed import is_master
+from .zero_shot import zero_shot_eval
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def unwrap_model(model):
+ if hasattr(model, "module"):
+ return model.module
+ else:
+ return model
+
+
+def train_one_epoch(
+ model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None, extra_suffix=""
+):
+ device = torch.device(args.device)
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ model.train()
+ loss = LPLoss(args.lp_loss)
+
+ dataloader, sampler = data["train"].dataloader, data["train"].sampler
+ if args.distributed and sampler is not None:
+ sampler.set_epoch(epoch)
+ num_batches_per_epoch = dataloader.num_batches
+ sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
+
+ # for toy dataset
+ if args.dataset_type == "toy":
+ dataloader.dataset.generate_queue()
+
+ loss_m = AverageMeter()
+ batch_time_m = AverageMeter()
+ data_time_m = AverageMeter()
+ end = time.time()
+
+ for i, batch in enumerate(dataloader):
+ step = num_batches_per_epoch * epoch + i
+
+ if isinstance(scheduler, dict):
+ for s in scheduler.values():
+ s(step)
+ else:
+ scheduler(step)
+
+        audio = batch  # contains mel_spec, waveform, and longer list
+ class_label = batch['class_label']
+ # audio = audio.to(device=device, non_blocking=True)
+ class_label = class_label.to(device=device, non_blocking=True)
+
+ if args.mixup:
+ # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
+ mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(audio["waveform"]))).to(device)
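+            # labels are mixed here with the same lambdas that are forwarded to the model
+            # below (mix_lambda=...), which presumably applies the matching mixup to the audio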
+ class_label = do_mixup(class_label, mix_lambda)
+ else:
+ mix_lambda = None
+
+ data_time_m.update(time.time() - end)
+ if isinstance(optimizer, dict):
+ for o_ in optimizer.values():
+ o_.zero_grad()
+ else:
+ optimizer.zero_grad()
+
+ with autocast():
+ pred = model(audio, mix_lambda=mix_lambda, device=device)
+ total_loss = loss(pred, class_label)
+
+ if isinstance(optimizer, dict):
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ for o_ in optimizer.values():
+ if args.horovod:
+ o_.synchronize()
+ scaler.unscale_(o_)
+ with o_.skip_synchronize():
+ scaler.step(o_)
+ else:
+ scaler.step(o_)
+ scaler.update()
+ else:
+ total_loss.backward()
+ for o_ in optimizer.values():
+ o_.step()
+ else:
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ if args.horovod:
+ optimizer.synchronize()
+ scaler.unscale_(optimizer)
+ with optimizer.skip_synchronize():
+ scaler.step(optimizer)
+ else:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ total_loss.backward()
+ optimizer.step()
+
+ # Note: we clamp to 4.6052 = ln(100), as in the original paper.
+ with torch.no_grad():
+ unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
+ unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))
+
+ batch_time_m.update(time.time() - end)
+ end = time.time()
+ batch_count = i + 1
+
+ if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
+ if isinstance(audio, dict):
+ batch_size = len(audio["waveform"])
+ else:
+ batch_size = len(audio)
+ num_samples = batch_count * batch_size * args.world_size
+ samples_per_epoch = dataloader.num_samples
+ percent_complete = 100.0 * batch_count / num_batches_per_epoch
+
+ # NOTE loss is coarsely sampled, just master node and per log update
+ loss_m.update(total_loss.item(), batch_size)
+ if isinstance(optimizer, dict):
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
+ )
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
+ }
+ else:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ )
+
+ # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "lr": optimizer.param_groups[0]["lr"],
+ }
+ for name, val in log_data.items():
+ name = f"train{extra_suffix}/{name}"
+ if tb_writer is not None:
+ tb_writer.add_scalar(name, val, step)
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ wandb.log({name: val, "step": step})
+
+ # resetting batch / data time meters per log window
+ batch_time_m.reset()
+ data_time_m.reset()
+ # end for
+
+def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
+ metrics = {}
+ if not args.parallel_eval:
+ if not is_master(args):
+ return metrics
+ device = torch.device(args.device)
+ model.eval()
+
+ # CHANGE
+ # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
+ # metrics.update(zero_shot_metrics)
+ if is_master(args):
+ print('Evaluating...')
+ metric_names = args.lp_metrics.split(',')
+ eval_tool = LPMetrics(metric_names=metric_names)
+
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ if "val" in data and (
+ args.val_frequency
+ and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
+ ):
+ if args.parallel_eval:
+ dataloader, sampler = data["val"].dataloader, data["val"].sampler
+ if args.distributed and sampler is not None:
+ sampler.set_epoch(epoch)
+ samples_per_val = dataloader.num_samples
+ else:
+ dataloader = data["val"].dataloader
+ num_samples = 0
+ samples_per_val = dataloader.num_samples
+
+ eval_info = {
+ 'pred': [],
+ 'target': []
+ }
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+                audio = batch  # contains mel_spec, waveform, and longer list
+ class_label = batch['class_label']
+
+ # audio = audio.to(device=device, non_blocking=True)
+ class_label = class_label.to(device=device, non_blocking=True)
+
+ with autocast():
+ pred = model(audio, device=device)
+ if args.parallel_eval:
+ pred, class_label = lp_gather_features(pred, class_label, args.world_size, args.horovod)
+ eval_info['pred'].append(pred)
+ eval_info['target'].append(class_label)
+
+ num_samples += class_label.shape[0]
+
+ if (i % 100) == 0: # and i != 0:
+ logging.info(
+ f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
+ )
+
+ if is_master(args):
+ eval_info['pred'] = torch.cat(eval_info['pred'], 0).cpu()
+ eval_info['target'] = torch.cat(eval_info['target'], 0).cpu()
+ metric_dict = eval_tool.evaluate_mertics(eval_info['pred'], eval_info['target'])
+ metrics.update(metric_dict)
+ if "epoch" not in metrics.keys():
+ metrics.update({"epoch": epoch})
+
+ if is_master(args):
+ if not metrics:
+ return metrics
+
+        logging.info(
+            f"Eval Epoch: {epoch} "
+            + "\n".join([f"{m}: {round(metrics[m], 4):.4f}" for m in metrics])
+        )
+ if args.save_logs:
+ for name, val in metrics.items():
+ if tb_writer is not None:
+ tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)
+
+ with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
+ f.write(json.dumps(metrics))
+ f.write("\n")
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ for name, val in metrics.items():
+ wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})
+
+ return metrics
+ else:
+ return metrics
diff --git a/src/laion_clap/training/main.py b/src/laion_clap/training/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c48b66e2287785961501dfd75f2b6f5d331c245
--- /dev/null
+++ b/src/laion_clap/training/main.py
@@ -0,0 +1,597 @@
+import logging
+import os
+import random
+from datetime import datetime
+import copy
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.cuda.amp import GradScaler
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+try:
+ import torch.utils.tensorboard as tensorboard
+except ImportError:
+ tensorboard = None
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from clap_module import create_model_and_transforms, trace_model, create_model
+from training.data import get_data
+from training.distributed import is_master, init_distributed_device, world_info_from_env
+from training.logger import setup_logging
+from training.params import parse_args
+from training.scheduler import cosine_lr
+from training.train import train_one_epoch, evaluate
+from clap_module.utils import dataset_split, get_optimizer
+
+
+def maintain_ckpts(args, startidx, all_idx_len):
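+    # shift existing epoch_top_{i}.pt checkpoints one slot down (i -> i+1) starting from
+    # startidx, then drop the one pushed past the last slot, making room for a new top-k ckpt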
+ for i in reversed(range(startidx, all_idx_len)):
+ if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
+ os.rename(
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
+ )
+ if os.path.exists(
+ os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
+ ):
+ os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
+ return
+
+
+def update_top_k_performance(
+ new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
+):
+ """
+ Record the top-k performance of the current epoch.
+    current_top_k_ckpt_metrics is a dictionary of the form: {0: top_1_ckpt_measure, 1: top_2_ckpt_measure, ...}
+    e.g. with --save-top-performance 3 it starts as {0: 0, 1: 0, 2: 0}; each value is later
+    replaced by the metric of the checkpoint occupying that slot.
+    """
+ if isinstance(new_metrics_inputs, (list, tuple)):
+ new_metrics_inputs = np.mean(new_metrics_inputs)
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, dict):
+ new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, (float, int)):
+ update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
+ sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
+ sorted_values = sorted(
+ current_top_k_ckpt_metrics.values(), reverse=bignumbetter
+ )
+ sorted_values_ = copy.deepcopy(sorted_values)
+ sorted_values.append(new_metrics_inputs)
+ sorted_values = sorted(sorted_values, reverse=bignumbetter)
+ sorted_values = sorted_values[:-1]
+
+ if sorted_values == sorted_values_:
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+ else:
+ for i in range(len(sorted_keys)):
+ if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
+ current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
+ update_flag[sorted_keys[i]] = True
+ for i in range(len(update_flag)):
+ if update_flag[i]:
+ maintain_ckpts(args, i, len(sorted_keys))
+ torch.save(
+ ckpt,
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ )
+ break
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+
+
+# def updateifNone(a, b):
+# a = b if None else a
+# return a
+
+
+def is_pretrained_params(n):
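+    # names matching the pretrained text tower (transformer blocks, token/positional
+    # embeddings, final LayerNorm, text logit scale); used below to route these parameters
+    # to the *_pretrained optimizer/hyperparameters when --split-opt is enabled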
+ return (
+ n.startswith("transformer")
+ or n in ["positional_embedding", "text_projection"]
+ or n.startswith("token_embedding")
+ or n.startswith("ln_final")
+ or n.startswith("logit_scale_t")
+ )
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
+
+
+def main():
+ args = parse_args()
+ # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+ args.amodel = args.amodel.replace("/", "-")
+ # download sizes.json file
+
+ # (yusong): the below two lines are for debug
+ # print("setting up faulthandler")
+ # faulthandler.register(10)
+
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart":
+ assert (
+ args.pretrained == "" or args.pretrained is None
+ ), "bert/roberta/bart text encoder does not support pretrained models."
+
+ # get the name of the experiments
+ if args.name is None:
+ args.name = "-".join(
+ [
+ datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
+ f"model_{args.amodel}",
+ f"lr_{args.lr}",
+ f"b_{args.batch_size}",
+ f"j_{args.workers}",
+ f"p_{args.precision}",
+ ]
+ )
+
+ # discover initial world args early so we can log properly
+ args.distributed = False
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+ if args.remotedata and is_master(args):
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ args.log_path = None
+ if is_master(args, local=args.log_local):
+ log_base_path = os.path.join(args.logs, args.name)
+ os.makedirs(log_base_path, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path, log_filename)
+ if os.path.exists(args.log_path):
+            print(
+                "Error. Experiment already exists. Use --name to specify a new experiment."
+            )
+ return -1
+
+ # Set logger
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ setup_logging(args.log_path, args.log_level)
+
+ # fully initialize distributed device environment
+ device = init_distributed_device(args)
+
+ args.wandb = "wandb" in args.report_to or "all" in args.report_to
+ args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
+ if is_master(args):
+ args.tensorboard_path = (
+ os.path.join(args.logs, args.name, "tensorboard")
+ if args.tensorboard
+ else ""
+ )
+ args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
+ for dirname in [args.tensorboard_path, args.checkpoint_path]:
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ else:
+ args.tensorboard_path = ""
+ args.checkpoint_path = ""
+
+ if args.copy_codebase:
+ copy_codebase(args)
+
+ assert args.precision in ["amp", "fp16", "fp32"]
+ if args.precision == "fp16":
+        logging.warning(
+            "It is recommended to use fp32 precision instead of fp16 or amp for this model, "
+            "as they can cause NaN loss and NaN gradients. "
+            "FP16 and AMP support needs further verification and tuning, especially for training."
+        )
+
+ if args.horovod:
+ logging.info(
+            f"Running in horovod mode with multiple processes / nodes. Device: {args.device}. "
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ elif args.distributed:
+ logging.info(
+            f"Running in distributed mode with multiple processes. Device: {args.device}. "
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ else:
+ logging.info(f"Running with a single process. Device {args.device}.")
+
+ logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")
+
+ model, model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=True,
+ pretrained_audio=args.pretrained_audio,
+ pretrained_text=args.pretrained_text,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ if args.horovod:
+ with torch.no_grad():
+ for param in model.parameters():
+ param.set_(param.contiguous())
+
+ if args.trace:
+ model = trace_model(model, batch_size=args.batch_size, device=device)
+
+ if is_master(args):
+ logging.info("Model:")
+ logging.info(f"{str(model)}")
+ logging.info("Params:")
+ params_file = os.path.join(args.logs, args.name, "params.txt")
+ with open(params_file, "w") as f:
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ logging.info(f" {name}: {val}")
+ f.write(f"{name}: {val}\n")
+
+ if args.distributed and not args.horovod:
+ if args.use_bn_sync:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ ddp_args = {}
+ if args.ddp_static_graph:
+ # this doesn't exist in older PyTorch, arg only added if enabled
+ ddp_args["static_graph"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[device], find_unused_parameters=True, **ddp_args
+ )
+
+ data = get_data(args, model_cfg)
+ assert len(data), "At least one train or eval dataset must be specified."
+ if args.trace:
+ assert "train" not in data, "Cannot train with traced model"
+
+ exclude = (
+ lambda n, p: p.ndim < 2
+ or "bn" in n
+ or "ln" in n
+ or "bias" in n
+ or "logit_scale" in n
+ )
+ include = lambda n, p: not exclude(n, p)
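+    # `exclude` selects gains/biases, norm and logit-scale parameters (trained without
+    # weight decay); `include` selects everything else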
+
+ named_parameters = list(model.named_parameters())
+
+ # freeze text encoder
+ text_freeze_parameters = [
+ p
+ for n, p in named_parameters
+ if 'text_branch' in n
+ ]
+
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+
+ gain_or_bias_params = [
+ p for n, p in named_parameters if exclude(n, p) and p.requires_grad
+ ]
+ rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
+
+    # set weight-decay-related params to 0 when using the adam optimizer
+ if args.optimizer == "adam":
+ args.wd = 0
+ args.wd_pretrained = 0
+ args.wd_new = 0
+
+ if args.train_data is None:
+ optimizer = None
+ scheduler = None
+ else:
+ total_steps = data["train"].dataloader.num_batches * args.epochs
+
+ if args.split_opt:
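+            # any per-group hyperparameter (e.g. lr_new, wd_pretrained) left unset falls
+            # back to the corresponding shared value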
+ for x in ["lr", "beta1", "beta2", "eps", "wd"]:
+ for y in ["_new", "_pretrained"]:
+ if getattr(args, x + y) is None:
+ setattr(args, x + y, getattr(args, x))
+
+ gain_or_bias_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ rest_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ gain_or_bias_new_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n))
+ ]
+ rest_new_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n))
+ ]
+ pretrained_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
+ {
+ "params": rest_pretrained_params,
+ "weight_decay": args.wd_pretrained,
+ },
+ ],
+ lr=args.lr_pretrained,
+ betas=(args.beta1_pretrained, args.beta2_pretrained),
+ eps=args.eps_pretrained,
+ momentum=args.momentum_pretrained,
+ optimizer_name=args.optimizer,
+ )
+ pretrained_params_scheduler = cosine_lr(
+ pretrained_params_optimizer,
+ args.lr_pretrained,
+ args.warmup,
+ total_steps,
+ )
+ new_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_new_params, "weight_decay": 0.0},
+ {"params": rest_new_params, "weight_decay": args.wd_new},
+ ],
+ lr=args.lr_new,
+ betas=(args.beta1_new, args.beta2_new),
+ eps=args.eps_new,
+ momentum=args.momentum_new,
+ optimizer_name=args.optimizer,
+ )
+
+ new_params_scheduler = cosine_lr(
+ new_params_optimizer, args.lr_new, args.warmup, total_steps
+ )
+
+ optimizer = {
+ "pretrained": pretrained_params_optimizer,
+ "new": new_params_optimizer,
+ }
+ scheduler = {
+ "pretrained": pretrained_params_scheduler,
+ "new": new_params_scheduler,
+ }
+
+ if args.horovod:
+ pretrained_params_optimizer = hvd.DistributedOptimizer(
+ pretrained_params_optimizer,
+ named_parameters=model.named_parameters(),
+ )
+ new_params_optimizer = hvd.DistributedOptimizer(
+ new_params_optimizer, named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0)
+ hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
+ else:
+ optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_params, "weight_decay": 0.0},
+ {"params": rest_params, "weight_decay": args.wd},
+ ],
+ lr=args.lr,
+ betas=(args.beta1, args.beta2),
+ eps=args.eps,
+ momentum=args.momentum,
+ optimizer_name=args.optimizer,
+ )
+
+ scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
+
+ if args.horovod:
+ optimizer = hvd.DistributedOptimizer(
+ optimizer, named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+
+ scaler = GradScaler() if args.precision == "amp" else None
+
+ # optionally resume from a checkpoint
+ start_epoch = 0
+ if args.resume is not None:
+ if os.path.isfile(args.resume):
+ checkpoint = torch.load(args.resume, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
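+                # checkpoints saved under DDP prefix every key with "module."; strip it
+                # when loading into a non-distributed model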
+ if not args.distributed and next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module.") :]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ if args.split_opt:
+ if optimizer is not None:
+ for k, o_ in optimizer.items():
+ o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if scaler is not None and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ logging.info(
+ f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ logging.info(
+ f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ else:
+ logging.info("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
+ args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+ writer = None
+ if args.save_logs and args.tensorboard:
+ assert tensorboard is not None, "Please install tensorboard."
+ writer = tensorboard.SummaryWriter(args.tensorboard_path)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, "Please install wandb."
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ wandb.init(
+ entity="clap",
+ project="clap",
+ notes=args.wandb_notes,
+ name=args.wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ if args.debug:
+ wandb.watch(model, log="all")
+ wandb.save(params_file)
+ logging.debug("Finished loading wandb.")
+
+ if "train" not in data:
+ evaluate(model, data, start_epoch, args, writer)
+ return
+ elif start_epoch == 0 and "val" in data and not args.no_eval:
+ evaluate(model, data, 0, args, writer)
+ # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug
+ if args.save_top_performance:
+ current_top_k_ckpt_metrics = {
+ i: 0 for i in range(args.save_top_performance)
+ } # initialize the top-k metric for ckpts to 0
+
+ # print(f'rank {args.rank}, Start Training') # (yusong): for debug
+ for epoch in range(start_epoch, args.epochs):
+ # freeze the text param after (include) args.freeze_text_after, this is -1 by default
+ if epoch == args.freeze_text_after:
+            print("Text pretrained parameters are frozen from this epoch onward.")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ if is_master(args):
+ logging.info(f"Start epoch {epoch}")
+
+ train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
+ completed_epoch = epoch + 1
+
+ if (
+ any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
+ and not args.no_eval
+ ):
+ metrics = evaluate(model, data, completed_epoch, args, writer)
+ if args.save_top_performance:
+ top_k_dataset = args.top_k_checkpoint_select_dataset
+ top_k_metric = args.top_k_checkpoint_select_metric
+ filtered_metrics = [
+ v
+ for k, v in metrics.items()
+ if top_k_metric in k and top_k_dataset in k
+                ]  # collect all metrics matching the selected metric/dataset (e.g. R@10) and use them to update the top-k checkpoints
+ # Saving checkpoints.
+ if args.save_logs:
+ if args.split_opt:
+ opt_dict = {
+ k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
+ }
+ else:
+ opt_dict = {"optimizer": optimizer.state_dict()}
+ checkpoint_dict = {
+ "epoch": completed_epoch,
+ "name": args.name,
+ "state_dict": model.state_dict(),
+ }
+ checkpoint_dict.update(opt_dict)
+ if scaler is not None:
+ checkpoint_dict["scaler"] = scaler.state_dict()
+
+ if completed_epoch == args.epochs or (
+ args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+ ):
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
+ )
+ if args.save_most_recent:
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
+ )
+ if args.save_top_performance and not args.no_eval:
+ update_top_k_performance(
+ filtered_metrics,
+ current_top_k_ckpt_metrics,
+ args,
+ checkpoint_dict,
+ bignumbetter=True,
+ )
+
+ if args.wandb and is_master(args):
+ wandb.finish()
+
+
+def copy_codebase(args):
+ from shutil import copytree, ignore_patterns
+
+ new_code_path = os.path.join(args.logs, args.name, "code")
+ if os.path.exists(new_code_path):
+ print(
+ f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
+ )
+ return -1
+ print(f"Copying codebase to {new_code_path}")
+ current_code_path = os.path.realpath(__file__)
+ for _ in range(3):
+ current_code_path = os.path.dirname(current_code_path)
+ copytree(
+ current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
+ )
+ print("Done copying code.")
+ return 1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/laion_clap/training/params.py b/src/laion_clap/training/params.py
new file mode 100644
index 0000000000000000000000000000000000000000..84cd3b43104007a14835e4de9cd4521899ba6345
--- /dev/null
+++ b/src/laion_clap/training/params.py
@@ -0,0 +1,567 @@
+import argparse
+
+
+def get_default_params(model_name):
+ # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
+ model_name = model_name.lower()
+ if "vit" in model_name:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
+ else:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--train-data",
+ type=str,
+ default=None,
+        help="Path to h5 file with training data",
+ )
+ parser.add_argument(
+ "--val-data",
+ type=str,
+ default=None,
+ help="Path to h5 file with validation data",
+ )
+ parser.add_argument(
+ "--freeze-text",
+ default=False,
+ action="store_true",
+        help="Set this flag to freeze the text encoder.",
+ )
+ parser.add_argument(
+ "--freeze-text-after",
+ type=int,
+ default=-1,
+        help="Freeze the text encoder from (and including) epoch x onward; set this param to x, or -1 to disable it.",
+ )
+ parser.add_argument(
+ "--train-ipc",
+ type=str,
+ default=None,
+        help="Path to npy file with the number of instances per class in the training data",
+ )
+ parser.add_argument(
+ "--val-ipc",
+ type=str,
+ default=None,
+        help="Path to npy file with the number of instances per class in the validation data",
+ )
+ parser.add_argument(
+ "--train-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Required for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--val-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--dataset-type",
+ choices=["webdataset", "csv", "auto", "toy"],
+ default="auto",
+ help="Which type of dataset to process.",
+ )
+ parser.add_argument(
+ "--csv-separator",
+ type=str,
+ default="\t",
+ help="For csv-like datasets, which separator to use.",
+ )
+ parser.add_argument(
+ "--csv-img-key",
+ type=str,
+ default="filepath",
+ help="For csv-like datasets, the name of the key for the image paths.",
+ )
+ parser.add_argument(
+ "--csv-caption-key",
+ type=str,
+ default="title",
+ help="For csv-like datasets, the name of the key for the captions.",
+ )
+ parser.add_argument(
+ "--imagenet-val",
+ type=str,
+ default=None,
+ help="Path to imagenet val set for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--imagenet-v2",
+ type=str,
+ default=None,
+ help="Path to imagenet v2 for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--datasetnames",
+ nargs="+",
+ default=None,
+        help="If loading webdataset, specify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
+ )
+ parser.add_argument(
+ "--full-train-dataset",
+ nargs="+",
+ default=None,
+        help="Which datasets will be trained with all subsets (train+test).",
+ )
+ parser.add_argument(
+ "--exclude-eval-dataset",
+ nargs="+",
+ default=None,
+        help="Which datasets will be excluded from evaluation.",
+ )
+ parser.add_argument(
+ "--datasetinfos",
+ nargs="+",
+ default=None,
+        help="If loading webdataset, specify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
+ )
+ parser.add_argument(
+ "--dataset-proportion",
+ type=float,
+ default=1.0,
+        help="The proportion of the dataset to train on.",
+ )
+ parser.add_argument(
+ "--remotedata",
+ default=False,
+ action="store_true",
+ help="if the dataset is remote, set this flag",
+ )
+ parser.add_argument(
+ "--class-label-path",
+ type=str,
+ default=None,
+ help="The path of the class label pickle or csv.",
+ )
+ parser.add_argument(
+ "--datasetpath",
+ type=str,
+ default="/mnt/audio_clip/webdataset_tar",
+ help="The path to the dataset",
+ )
+ parser.add_argument(
+ "--logs",
+ type=str,
+ default="./logs/",
+ help="Where to store tensorboard logs. Use None to avoid storing logs.",
+ )
+ parser.add_argument(
+ "--log-local",
+ action="store_true",
+ default=False,
+ help="log files on local master, otherwise global master only.",
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default=None,
+ help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
+ )
+ parser.add_argument(
+ "--workers", type=int, default=1, help="Number of workers per GPU."
+ )
+ parser.add_argument(
+ "--batch-size", type=int, default=64, help="Batch size per GPU."
+ )
+ parser.add_argument(
+ "--epochs", type=int, default=32, help="Number of epochs to train for."
+ )
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
+ parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
+ parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
+ parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+    parser.add_argument("--momentum", type=float, default=None, help="SGD momentum.")
+ parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
+
+ parser.add_argument(
+ "--split-opt",
+ action="store_true",
+ default=False,
+        help="Use separate optimizers (and hyperparameters) for the pretrained and new parameters.",
+ )
+ parser.add_argument(
+ "--lr-pretrained", type=float, default=None, help="Learning rate for text."
+ )
+ parser.add_argument(
+ "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
+ )
+ parser.add_argument(
+ "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
+ )
+ parser.add_argument(
+ "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
+ )
+ parser.add_argument(
+ "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
+ )
+ parser.add_argument(
+ "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
+ )
+ parser.add_argument(
+ "--lr-new", type=float, default=None, help="Learning rate for audio."
+ )
+ parser.add_argument(
+ "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
+ )
+ parser.add_argument(
+ "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
+ )
+ parser.add_argument(
+ "--eps-new", type=float, default=None, help="Adam epsilon for audio."
+ )
+ parser.add_argument(
+ "--wd-new", type=float, default=0.2, help="Weight decay for audio."
+ )
+ parser.add_argument(
+ "--momentum-new", type=float, default=0.9, help="Momentum for audio."
+ )
+ parser.add_argument(
+ "--warmup", type=int, default=10000, help="Number of steps to warmup for."
+ )
+ parser.add_argument(
+ "--use-bn-sync",
+ default=False,
+ action="store_true",
+ help="Whether to use batch norm sync.",
+ )
+ parser.add_argument(
+ "--skip-scheduler",
+ action="store_true",
+ default=False,
+ help="Use this flag to skip the learning rate decay.",
+ )
+ parser.add_argument(
+ "--save-frequency", type=int, default=1, help="How often to save checkpoints."
+ )
+ parser.add_argument(
+ "--save-top-performance",
+ type=int,
+ default=0,
+        help="Save the top-x best-performing checkpoints if the value is > 0.",
+ )
+ parser.add_argument(
+ "--save-most-recent",
+ action="store_true",
+ default=False,
+ help="Always save the most recent model trained to epoch_latest.pt.",
+ )
+ parser.add_argument(
+ "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
+ )
+ parser.add_argument(
+ "--val-frequency",
+ type=int,
+ default=1,
+ help="How often to run evaluation with val data.",
+ )
+ parser.add_argument(
+ "--resume",
+ default=None,
+ type=str,
+ help="path to latest checkpoint (default: none)",
+ )
+ parser.add_argument(
+ "--precision",
+ choices=["amp", "fp16", "fp32"],
+ default="amp",
+ help="Floating point precision.",
+ )
+ parser.add_argument(
+ "--amodel",
+ type=str,
+ default="RN50",
+ help="Name of the audio backbone to use.",
+ )
+ parser.add_argument(
+ "--tmodel",
+ type=str,
+ default="transformer",
+ help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
+ )
+ parser.add_argument(
+ "--pretrained-audio",
+ default="",
+ type=str,
+        help="Use pretrained audio model weights for the audio encoder of CLAP.",
+ )
+ parser.add_argument(
+ "--pretrained-text",
+ default="",
+ type=str,
+        help="Use pretrained text model weights for the text encoder of CLAP.",
+ )
+ parser.add_argument(
+ "--pretrained",
+ default="",
+ type=str,
+        help="Use pretrained CLIP model weights with the specified tag or file path.",
+ )
+ parser.add_argument(
+ "--pretrained-image",
+ default=False,
+ action="store_true",
+ help="Load imagenet pretrained weights for image tower backbone if available.",
+ )
+ parser.add_argument(
+ "--lock-image",
+ default=False,
+ action="store_true",
+ help="Lock full image tower by disabling gradients.",
+ )
+ parser.add_argument(
+ "--lock-image-unlocked-groups",
+ type=int,
+ default=0,
+ help="Leave last n image tower layer groups unlocked.",
+ )
+ parser.add_argument(
+ "--lock-image-freeze-bn-stats",
+ default=False,
+ action="store_true",
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
+ )
+ parser.add_argument(
+ "--local-loss",
+ default=False,
+ action="store_true",
+ help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
+ )
+ parser.add_argument(
+ "--gather-with-grad",
+ default=False,
+ action="store_true",
+ help="enable full distributed gradient for feature gather",
+ )
+ parser.add_argument(
+ "--force-quick-gelu",
+ default=False,
+ action="store_true",
+ help="Force use of QuickGELU activation for non-OpenAI transformer models.",
+ )
+ parser.add_argument(
+ "--torchscript",
+ default=False,
+ action="store_true",
+ help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
+ )
+ parser.add_argument(
+ "--trace",
+ default=False,
+ action="store_true",
+ help="torch.jit.trace the model for inference / eval only",
+ )
+ # arguments for distributed training
+ parser.add_argument(
+ "--dist-url",
+ default="env://",
+ type=str,
+ help="url used to set up distributed training",
+ )
+ parser.add_argument(
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
+ )
+ parser.add_argument(
+ "--report-to",
+ default="",
+ type=str,
+ help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
+ )
+ parser.add_argument(
+ "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
+ )
+ parser.add_argument(
+ "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
+ )
+ parser.add_argument(
+ "--debug",
+ default=False,
+ action="store_true",
+ help="If true, more information is logged.",
+ )
+ parser.add_argument(
+ "--copy-codebase",
+ default=False,
+ action="store_true",
+        help="If true, we copy the entire codebase to the log directory and execute from there.",
+ )
+ parser.add_argument(
+ "--horovod",
+ default=False,
+ action="store_true",
+ help="Use horovod for distributed training.",
+ )
+ parser.add_argument(
+ "--ddp-static-graph",
+ default=False,
+ action="store_true",
+ help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
+ )
+ parser.add_argument(
+ "--no-set-device-rank",
+ default=False,
+ action="store_true",
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+ )
+ parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
+
+ parser.add_argument(
+ "--top-k-checkpoint-select-dataset",
+ type=str,
+ default="all",
+        help="The dataset used for selecting the top-k checkpoints.",
+ )
+
+    # R@10, R@5, R@1, mAP@10
+ parser.add_argument(
+ "--top-k-checkpoint-select-metric",
+ type=str,
+ default="_R@10",
+        help="The metric used for selecting the top-k checkpoints.",
+ )
+ parser.add_argument(
+ "--openai-model-cache-dir",
+ type=str,
+ default="~/.cache/clip",
+ help="Directory to download OpenAI models.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="adamw",
+ help="can be AdamW or SGD",
+ )
+ parser.add_argument(
+ "--parallel-eval",
+ default=False,
+ action="store_true",
+ help="Eval in parallel (multi-GPU, multi-node).",
+ )
+
+ parser.add_argument(
+ "--no-eval",
+ default=False,
+ action="store_true",
+ help="Training without evaluation.",
+ )
+
+ parser.add_argument(
+ "--lp-mlp",
+ default=False,
+ action="store_true",
+ help="Linear Probe using MLP layer or not.",
+ )
+
+ parser.add_argument(
+ "--lp-freeze",
+ default=False,
+ action="store_true",
+        help="Freeze the CLAP backbone during linear probing.",
+ )
+
+ parser.add_argument(
+ "--lp-act",
+ default="None",
+ type=str,
+ help="Options are ['relu','elu','prelu','softmax','sigmoid']",
+ )
+
+ parser.add_argument(
+ "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
+ )
+
+ parser.add_argument(
+ "--lp-metrics",
+ type=str,
+ default="map,mauc,acc",
+ help="Metrics of Linear Probe.",
+ )
+
+ parser.add_argument(
+ "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
+ )
+ parser.add_argument(
+ "--kappa", type=float, default=0,
+        help="The kappa for the weighted contrastive loss; the default (0) disables the weighting."
+ )
+
+ parser.add_argument(
+ "--data-filling",
+ type=str,
+ default="pad",
+ help="type of data filling when the audio length is shorter than the max length."
+ "Can be one of the following: repeat, repeatpad, pad",
+ )
+ parser.add_argument(
+ "--data-truncating",
+ type=str,
+ default="rand_trunc",
+ help="type of data truncation when the audio length is longer than the max length."
+ "Can be one of the following: rand_trunc, fusion",
+ )
+
+ parser.add_argument(
+ "--clap-mlploss",
+ default=False,
+ action="store_true",
+ help="Using MLP loss for CLAP model or not",
+ )
+
+ parser.add_argument(
+ "--wandb-id",
+ type=str,
+ default=None,
+ help="the id of wandb experiment to restore.",
+ )
+
+    parser.add_argument(
+        "--sleep", type=float, default=0, help="sleep n seconds before starting training"
+    )
+
+ # variable length processing
+ parser.add_argument(
+ "--enable-fusion",
+ default=False,
+ action="store_true",
+        help="Enable feature fusion for variable-length data",
+ )
+
+ parser.add_argument(
+ "--fusion-type",
+ type=str,
+ default='None',
+ help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
+ )
+
+ parser.add_argument(
+ "--mixup",
+ default=False,
+ action="store_true",
+ help="Enable mixup in finetuning training.",
+ )
+ parser.add_argument(
+ "--text-augment-selection",
+ type=str,
+ default=None,
+ help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
+ )
+ parser.add_argument(
+ "--prefetch-factor",
+ type=int,
+ default=None,
+        help="The prefetch factor for the dataloader. A larger value uses more memory and CPU but is faster.",
+ )
+
+ args = parser.parse_args()
+
+ # If some params are not passed, we use the default values based on model name.
+ default_params = get_default_params(args.amodel)
+ for name, val in default_params.items():
+ if getattr(args, name) is None:
+ setattr(args, name, val)
+
+ return args
diff --git a/src/laion_clap/training/scheduler.py b/src/laion_clap/training/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0bfdf796c95df003582ede43f0511bd9181c1e4
--- /dev/null
+++ b/src/laion_clap/training/scheduler.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+def assign_learning_rate(optimizer, new_lr):
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+ return base_lr * (step + 1) / warmup_length
+
+
+def cosine_lr(optimizer, base_lr, warmup_length, steps):
+ def _lr_adjuster(step):
+ if step < warmup_length:
+ lr = _warmup_lr(base_lr, warmup_length, step)
+ else:
+ e = step - warmup_length
+ es = steps - warmup_length
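+            # cosine decay from base_lr at the end of warmup down to 0 at the final step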
+ lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+ assign_learning_rate(optimizer, lr)
+ return lr
+ return _lr_adjuster
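+
+
+# Minimal usage sketch (hypothetical values, assuming `optimizer` already exists):
+#   scheduler = cosine_lr(optimizer, base_lr=1e-4, warmup_length=1000, steps=100000)
+#   for step in range(100000):
+#       scheduler(step)  # assigns the warmup/cosine lr to every param group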
\ No newline at end of file
diff --git a/src/laion_clap/training/train.py b/src/laion_clap/training/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..06db94158fe15c0a17c95babf60b75d31e8893c7
--- /dev/null
+++ b/src/laion_clap/training/train.py
@@ -0,0 +1,781 @@
+import json
+import logging
+import math
+import os
+import time
+from contextlib import suppress
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from clap_module import ClipLoss, gather_features
+from .distributed import is_master
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def unwrap_model(model):
+ if hasattr(model, "module"):
+ return model.module
+ else:
+ return model
+
+
+def train_one_epoch(
+ model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
+):
+ device = torch.device(args.device)
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ model.train()
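+    # CLIP-style contrastive loss between audio and text embeddings; with --clap-mlploss an
+    # additional contrastive term over the MLP-projected features (and logit_scale_t) is used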
+ loss = ClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ mlp_loss=args.clap_mlploss,
+ weight_loss_kappa=args.kappa,
+ )
+
+ dataloader, sampler = data["train"].dataloader, data["train"].sampler
+ if args.distributed and sampler is not None:
+ sampler.set_epoch(epoch)
+ num_batches_per_epoch = dataloader.num_batches
+ sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
+
+ # for toy dataset
+ if args.dataset_type == "toy":
+ dataloader.dataset.generate_queue()
+
+ loss_m = AverageMeter()
+ batch_time_m = AverageMeter()
+ data_time_m = AverageMeter()
+ end = time.time()
+
+ for i, batch in enumerate(dataloader):
+ # logging.info(f"batch {i} of {num_batches_per_epoch}")
+ step = num_batches_per_epoch * epoch + i
+ if isinstance(scheduler, dict):
+ for s in scheduler.values():
+ s(step)
+ else:
+ scheduler(step)
+        audios = batch  # contains mel_spec, waveform, and longer list
+ texts = batch['text']
+ # audios = audios.to(device=device, non_blocking=True)
+ # texts = texts.to(device=device, non_blocking=True)
+
+ data_time_m.update(time.time() - end)
+ if isinstance(optimizer, dict):
+ for o_ in optimizer.values():
+ o_.zero_grad()
+ else:
+ optimizer.zero_grad()
+
+ with autocast():
+ (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ logit_scale_a,
+ logit_scale_t,
+ ) = model(audios, texts, device)
+
+ if args.clap_mlploss:
+ total_loss = loss(
+ audio_features=audio_features,
+ text_features=text_features,
+ logit_scale_a=logit_scale_a,
+ logit_scale_t=logit_scale_t,
+ audio_features_mlp=audio_features_mlp,
+ text_features_mlp=text_features_mlp
+ )
+ else:
+ total_loss = loss(
+ audio_features=audio_features,
+ text_features=text_features,
+ logit_scale_a=logit_scale_a
+ )
+ if isinstance(optimizer, dict):
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ for o_ in optimizer.values():
+ if args.horovod:
+ o_.synchronize()
+ scaler.unscale_(o_)
+ with o_.skip_synchronize():
+ scaler.step(o_)
+ else:
+ scaler.step(o_)
+ scaler.update()
+ else:
+ total_loss.backward()
+ for o_ in optimizer.values():
+ o_.step()
+ else:
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ if args.horovod:
+ optimizer.synchronize()
+ scaler.unscale_(optimizer)
+ with optimizer.skip_synchronize():
+ scaler.step(optimizer)
+ else:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ total_loss.backward()
+ optimizer.step()
+
+ # Note: we clamp to 4.6052 = ln(100), as in the original paper.
+ with torch.no_grad():
+ unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
+ if args.clap_mlploss:
+ unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))
+
+ batch_time_m.update(time.time() - end)
+ end = time.time()
+ batch_count = i + 1
+ if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
+ if isinstance(audios, dict):
+ batch_size = len(audios["waveform"])
+ else:
+ batch_size = len(audios)
+ num_samples = batch_count * batch_size * args.world_size
+ samples_per_epoch = dataloader.num_samples
+ percent_complete = 100.0 * batch_count / num_batches_per_epoch
+
+ # NOTE loss is coarsely sampled, just master node and per log update
+ loss_m.update(total_loss.item(), batch_size)
+ logit_scale_scalar_a = logit_scale_a.item()
+ logit_scale_scalar_t = logit_scale_t.item()
+ if isinstance(optimizer, dict):
+ if args.clap_mlploss:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
+ )
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "scale_text": logit_scale_scalar_t,
+ "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
+ }
+ else:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ )
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
+ }
+
+ else:
+ if args.clap_mlploss:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
+ )
+
+ # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "scale_text": logit_scale_scalar_t,
+ "lr": optimizer.param_groups[0]["lr"],
+ }
+ else:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ )
+
+ # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "lr": optimizer.param_groups[0]["lr"],
+ }
+ for name, val in log_data.items():
+ name = "train/" + name
+ if tb_writer is not None:
+ tb_writer.add_scalar(name, val, step)
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ wandb.log({name: val, "step": step})
+
+ # resetting batch / data time meters per log window
+ batch_time_m.reset()
+ data_time_m.reset()
+ # end for
+
+
+def evaluate(model, data, epoch, args, tb_writer=None):
+ metrics = {}
+ if not args.parallel_eval:
+ if not is_master(args):
+ return metrics
+ device = torch.device(args.device)
+ model.eval()
+
+ # CHANGE
+ # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
+ # metrics.update(zero_shot_metrics)
+ if is_master(args):
+ print('Evaluating...')
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ if args.val_dataset_names == ['Clotho', 'audiocaps']:
+ # if only clotho and audiocaps are used, then we will use a different evaluation function.
+        # This is because in the Clotho and audiocaps valid and test sets, there are 5 texts per audio.
+ if args.parallel_eval:
+ # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
+ raise NotImplementedError("Parallel evaluation not supported for eval only Clotho and audiocaps.")
+ val_metrics_per_dataset = evaluate_clotho_audiocaps(model, data, epoch, args, autocast, device, tb_writer)
+ for m in val_metrics_per_dataset.values():
+ metrics.update(m)
+ if "epoch" not in metrics.keys():
+ metrics.update({"epoch": epoch})
+ metrics = select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args)
+ elif "val" in data and (
+ args.val_frequency
+ and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
+ ):
+ dataloader = data["val"].dataloader
+ num_samples = 0
+ samples_per_val = dataloader.num_samples
+
+ # FIXME this does not scale past small eval datasets
+ # all_audio_features @ all_text_features will blow up memory and compute very quickly
+ eval_info = {}
+ if args.clap_mlploss:
+ eval_info["all"] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+ "all_audio_features_mlp": [],
+ "all_text_features_mlp": []
+ } # cumulative_loss = 0.0
+ else:
+ eval_info["all"] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": []
+            }  # cumulative_loss = 0.0
+ # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], []
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+                audios = batch  # contains mel_spec, waveform, and longer list
+ texts = batch['text']
+ # audios = audios.to(device=device, non_blocking=True)
+
+ all_names = list(set(["-".join(b.split("/")[-3:-1]) for b in batch['__url__']]))
+ for name in all_names:
+ if name not in eval_info.keys():
+ if args.clap_mlploss:
+ eval_info[name] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+ "all_audio_features_mlp": [],
+ "all_text_features_mlp": [],
+ }
+ else:
+ eval_info[name] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": []
+ }
+ with autocast():
+ (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ logit_scale_a,
+ logit_scale_t,
+ ) = model(audios, texts, device)
+
+ if args.parallel_eval:
+ # multi-GPU eval
+ if args.clap_mlploss:
+ (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ ) = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ audio_features_mlp=audio_features_mlp,
+ text_features_mlp=text_features_mlp,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ mlp_loss=args.clap_mlploss
+ )
+ else:
+ (
+ audio_features,
+ text_features,
+ ) = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ mlp_loss=args.clap_mlploss
+ )
+
+ if is_master(args):
+ num_samples += audio_features.shape[0]
+ for n in [*all_names, "all"]:
+ if n == "all":
+ eval_info[n]["all_audio_features"].append(
+ audio_features.cpu()
+ )
+ eval_info[n]["all_text_features"].append(
+ text_features.cpu()
+ )
+ if args.clap_mlploss:
+ eval_info[n]["all_audio_features_mlp"].append(
+ audio_features_mlp.cpu()
+ )
+ eval_info[n]["all_text_features_mlp"].append(
+ text_features_mlp.cpu()
+ )
+ else:
+ idx = np.where(
+ np.array(
+ ["-".join(b.split("/")[-3:-1]) for b in batch['__url__']]
+ )
+ == n
+ )[0]
+ eval_info[n]["all_audio_features"].append(
+ audio_features.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ eval_info[n]["all_text_features"].append(
+ text_features.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ if args.clap_mlploss:
+ eval_info[n]["all_audio_features_mlp"].append(
+ audio_features_mlp.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ eval_info[n]["all_text_features_mlp"].append(
+ text_features_mlp.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ # print(f'eval step {i}') # (yusong): for debug
+
+ # cumulative_loss += total_loss * batch_size
+ # num_samples += batch_size
+ if is_master(args) and (i % 100) == 0: # and i != 0:
+ logging.info(
+ f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
+ )
+ if is_master(args):
+ val_metrics_per_dataset = {}
+ for n in eval_info.keys():
+ if args.clap_mlploss:
+ metrics_single_dataset = get_metrics(
+ audio_features=torch.cat(eval_info[n]["all_audio_features"]),
+ text_features=torch.cat(eval_info[n]["all_text_features"]),
+ logit_scale_a=logit_scale_a.cpu(),
+ audio_features_mlp=torch.cat(
+ eval_info[n]["all_audio_features_mlp"]
+ ),
+ text_features_mlp=torch.cat(eval_info[n]["all_text_features_mlp"]),
+ logit_scale_t=logit_scale_t.cpu(),
+ mlp_loss=args.clap_mlploss
+ )
+ else:
+ metrics_single_dataset = get_metrics(
+ audio_features=torch.cat(eval_info[n]["all_audio_features"]),
+ text_features=torch.cat(eval_info[n]["all_text_features"]),
+ logit_scale_a=logit_scale_a.cpu(),
+ mlp_loss=args.clap_mlploss
+ )
+ val_metrics_per_dataset[n] = {
+ n + "/" + k: v for k, v in metrics_single_dataset.items()
+ }
+ metrics.update(val_metrics_per_dataset[n])
+ if "epoch" not in metrics.keys():
+ metrics.update({"epoch": epoch})
+ if is_master(args):
+ if not metrics:
+ return metrics
+
+ logging.info(
+ f"Eval Epoch: {epoch} "
+ + "\n".join(
+ [
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
+ for m in val_metrics_per_dataset.values()
+ ]
+ )
+ )
+
+ if args.save_logs:
+ for name, val in metrics.items():
+ if tb_writer is not None:
+ tb_writer.add_scalar(f"val/{name}", val, epoch)
+
+ with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
+ f.write(json.dumps(metrics))
+ f.write("\n")
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ for name, val in metrics.items():
+ wandb.log({f"val/{name}": val, "epoch": epoch})
+
+ return metrics
+ else:
+ return metrics
+
+
+def get_metrics(
+ audio_features,
+ text_features,
+ logit_scale_a,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ logit_scale_t=None,
+ mlp_loss=False
+):
+ metrics = {}
+ if mlp_loss:
+        # Set up the audio-to-text & text-to-audio similarity matrices
+ a_logits_per_audio = (
+ (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu()
+ )
+ a_logits_per_text = a_logits_per_audio.t().detach().cpu()
+ t_logits_per_audio = (
+ (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu()
+ )
+ t_logits_per_text = t_logits_per_audio.t().detach().cpu()
+
+ labels = torch.arange(audio_features.shape[0]).long()
+ # Change the loss from two terms into four terms with 2x2 combined CE loss
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels)
+ + F.cross_entropy(a_logits_per_text, labels)
+ + F.cross_entropy(t_logits_per_audio, labels)
+ + F.cross_entropy(t_logits_per_text, labels)
+ ) / 4
+
+ metrics[f"cumulative_loss"] = total_loss.item()
+ metrics[f"num_samples"] = audio_features.shape[0]
+
+ logits = {
+ "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2,
+ "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2,
+ }
+ ground_truth = torch.arange(len(text_features)).view(-1, 1)
+
+ else:
+ # print("text_features", text_features)
+ # print("text_features.shape", text_features.shape)
+ logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ labels = torch.arange(audio_features.shape[0]).long()
+ # Change the loss from two terms into four terms with 2x2 combined CE loss
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels)
+ + F.cross_entropy(logits_per_text, labels)
+ ) / 2
+
+ metrics[f"cumulative_loss"] = total_loss.item()
+ metrics[f"num_samples"] = audio_features.shape[0]
+
+ logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text}
+
+ ground_truth = torch.arange(len(text_features)).view(-1, 1)
+
+ for name, logit in logits.items():
+ ranking = torch.argsort(logit, descending=True)
+ preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread
+ preds = preds.detach().cpu().numpy()
+ metrics[f"{name}_mean_rank"] = preds.mean() + 1
+ metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"{name}_R@{k}"] = np.mean(preds < k)
+ # map@10
+ metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
+
+ return metrics
+
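+
+def _get_metrics_smoke_test(num_samples=8, dim=16):
+    # Illustrative sketch (added; the helper name is hypothetical and not part of the
+    # original training code): shows how get_metrics above is intended to be called on
+    # pooled eval features. The random placeholder features make the ranks meaningless.
+    audio_features = F.normalize(torch.randn(num_samples, dim), dim=-1)
+    text_features = F.normalize(torch.randn(num_samples, dim), dim=-1)
+    return get_metrics(
+        audio_features=audio_features,
+        text_features=text_features,
+        logit_scale_a=torch.tensor(100.0),
+        mlp_loss=False,
+    )  # expect keys such as "audio_to_text_R@1" and "text_to_audio_mAP@10"
+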
+
+def evaluate_clotho_audiocaps(
+ model, data, epoch, args, autocast, device, tb_writer=None
+):
+ """
+ Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
+    1. for text-to-audio retrieval, run the retrieval 5 times (once per caption) and average the results
+    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among the 5 texts
+    3. for mAP@10 in audio-to-text retrieval:
+        3.1: sort the ranks of the 5 texts
+        3.2: exclude ranks >= 10 (0-indexed)
+        3.3: compute the mAP over the remaining ranks: np.sum(np.arange(1, len(ranks) + 1) / (ranks + 1)) / 5,
+            so texts ranked >= 10 contribute 0.
+            (3.3) That is, take the ranks of the 5 texts that are < 10 and assign ascending hit counts as ground truth:
+            (3.3) the best-ranked of the 5 texts counts as hit 1, the second best as hit 2, etc.
+ """
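+    # Worked toy example of step 3 above (added for clarity, not in the original code):
+    # if the 5 captions of one audio land at (0-indexed) ranks [0, 3, 11, 20, 41],
+    # ranks >= 10 are dropped, leaving [0, 3]; these count as hits 1 and 2, so the
+    # per-audio mAP@10 is (1/1 + 2/4) / 5 = 0.3.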
+ # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
+ dataloader = data["val"].dataloader
+ with torch.no_grad():
+ eval_info = {}
+ for i, batch in enumerate(dataloader):
+            audios = batch  # contains mel_spec, waveform, and the longer list
+
+ # each item in the list has 5 texts
+ if args.tmodel == "transformer":
+ from clap_module import tokenize
+ texts = [tokenize(t) for t in batch['full_text']]
+ texts = torch.cat(texts)
+ else:
+ from .data import tokenizer
+ texts = [tokenizer(t, tmodel=args.tmodel) for t in batch['full_text']] # 5 texts for each audio
+ texts = {k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()} # 5 x batch
+
+ # audios = audios.to(device=device, non_blocking=True)
+
+            # batch['__url__'] contains the path to the data tar this sample comes from,
+            # so joining b.split("/")[-3:-1] with '-' yields a per-dataset identifier
+            # (dataset name and split).
+ all_names = list(set(["-".join(b.split("/")[-3:-1]) for b in batch['__url__']]))
+ for name in all_names:
+ if name not in eval_info.keys():
+ # we will not use mlp outputs even if args.clap_mlploss=True
+ eval_info[name] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": []
+ }
+ with autocast():
+ audio_features = model(audios, None, device)
+ text_features = model(None, texts, device)
+ audio_features = F.normalize(audio_features, dim=-1)
+ text_features = F.normalize(text_features, dim=-1)
+
+ all_names = list(set(["-".join(b.split("/")[-3:-1]) for b in batch['__url__']]))
+ for n in all_names:
+ idx = np.where(
+ np.array(
+ ["-".join(b.split("/")[-3:-1]) for b in batch['__url__']]
+ )
+ == n
+ )[0]
+ eval_info[n]["all_audio_features"].append(
+ audio_features.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ # (yusong) please double-check. This is for selecting 5 text features at once.
+                    # because idx holds the audio indices for this dataset within the batch,
+                    # while text_features is a tensor of size (5*num_samples, dim),
+                    # we need to select 5 consecutive text rows for each index in idx.
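+                    # e.g. a batch of 4 audios gives text_features of shape (20, dim);
+                    # reshape to (4, 5, dim), select the rows in idx (say idx=[1, 3] -> (2, 5, dim)),
+                    # then flatten back to (10, dim) so each selected audio stays aligned with its own 5 texts.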
+ eval_info[n]["all_text_features"].append(
+ text_features.cpu().reshape([-1, 5, text_features.shape[1]]).index_select(
+ 0, torch.tensor(idx).long()
+ ).reshape([-1, text_features.shape[1]])
+ )
+
+ val_metrics_all = {}
+
+ for n in eval_info.keys():
+ logit_scale_a, logit_scale_t = model(None, None, device)
+ logit_scale_a = logit_scale_a.cpu()
+
+ audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
+ text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)
+
+ logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ # logits_per_audio shape: [num_samples, num_samples*5]
+ # logits_per_text shape: [num_samples*5, num_samples]
+
+ logging.info(f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
+ f"logits_per_text shape: {logits_per_text.shape}")
+
+ metrics = {}
+ num_samples = audio_features.shape[0]
+ metrics[f"num_samples"] = num_samples
+
+ # (yusong) the following code is very important, please double-check:
+ # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
+ # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
+ # Those two are retrieving one of the 5 text for each audio.
+ labels = torch.arange(audio_features.shape[0]).long()
+ audio_to_text_loss = [
+ F.cross_entropy(
+ logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], labels) for d in range(5)
+ ]
+ text_to_audio_loss = [
+ F.cross_entropy(
+ logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], labels) for d in range(5)
+ ]
+ total_loss = (
+ np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)
+ ) / 2
+
+ metrics[f"cumulative_loss"] = total_loss.item()
+
+ # text to audio: do 5 times
+ pred_text = []
+ for d in range(5):
+ logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
+ ground_truth = torch.arange(len(logit)).view(-1, 1)
+ ranking = torch.argsort(logit, descending=True) # [num_samples, num_samples]
+ preds = torch.where(ranking == ground_truth)[1]
+ pred_text.append(preds.detach().cpu().numpy())
+ pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples]
+ metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
+ metrics[f"text_to_audio_median_rank"] = np.floor(np.median(pred_text_concat)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
+ # map@10
+ metrics[f"text_to_audio_mAP@10"] = np.mean(np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0))
+
+ # audio to text: take the best result
+ # for audio to text map 10, sort and assign descending ground truth.
+ # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
+ # map@10
+ map_all = []
+ pred_audio_all = []
+ for d in range(num_samples):
+ # logits_per_audio: [num_samples, num_samples*5]
+ logit_single = logits_per_audio[d, :] # [5*num_samples]
+ # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4]
+ ranking = torch.argsort(logit_single, descending=True) # [5*num_samples]
+            # ranking[j] is the index of the text with the j-th highest score for this audio
+ ground_truth = torch.arange(d * 5, d * 5 + 5)[None]
+ all_pred = torch.where(torch.stack([ranking] * 5) == ground_truth.view(-1, 1))[1]
+ min_pred = torch.min(all_pred)
+ pred_audio_all.append(min_pred.detach().cpu().numpy())
+ all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy()
+            # divide by 5 because there are 5 texts, so any text ranked >= 10 contributes 0.
+ map_single = np.sum((np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))) / 5
+ map_all.append(map_single)
+ metrics[f"audio_to_text_mAP@10"] = np.mean(map_all)
+ for k in [1, 5, 10]:
+ metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)
+
+ val_metrics_all[n] = {
+ n + "/" + k: v for k, v in metrics.items()
+ }
+ return val_metrics_all
+
+
+def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
+ """
+ Calculate performance for Clotho+AudioCaps for model selection.
+ """
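+    # Toy example (added for clarity): with Clotho at a2t/t2a mAP@10 of 0.20/0.16 and
+    # AudioCaps at 0.40/0.36, the per-dataset scores are 0.18 and 0.38 and the
+    # returned selection score is their mean, 0.28.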
+ selection_performance_all = []
+ for n in val_metrics_per_dataset.keys():
+ selection_performance = (val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] +
+ val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"]) / 2
+ selection_performance_all.append(selection_performance)
+ return np.mean(selection_performance_all)
+
+
+def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
+ # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value
+ # metrics: dict, key: metric name, value: metric value
+ # Hack: use args to save the top performance
+ if not hasattr(args, "top_selection_performance"):
+ selection_performance = calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset)
+ # TODO: write the if and else together
+ metric_update = {}
+ for n in val_metrics_per_dataset.keys():
+ for k in val_metrics_per_dataset[n].keys():
+ metric_update[k.split('/')[0] + '-top' + '/' + k.split('/')[1]] = val_metrics_per_dataset[n][k]
+ metric_update['top_selection_performance'] = selection_performance
+ metric_update['top-selection-epoch'] = metrics['epoch']
+ metrics.update(metric_update)
+ args.top_metric = metric_update
+ args.top_selection_performance = selection_performance
+ else:
+ selection_performance_new = calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset)
+ selection_performance_old = args.top_selection_performance
+ if selection_performance_new > selection_performance_old:
+ metric_update = {}
+ for n in val_metrics_per_dataset.keys():
+ for k in val_metrics_per_dataset[n].keys():
+ metric_update[k.split('/')[0] + '-top' + '/' + k.split('/')[1]] = val_metrics_per_dataset[n][k]
+ metric_update['top_selection_performance'] = selection_performance_new
+ metric_update['top-selection-epoch'] = metrics['epoch']
+ metrics.update(metric_update)
+ args.top_metric = metric_update
+ args.top_selection_performance = selection_performance_new
+ else:
+ metrics.update(args.top_metric)
+ return metrics
diff --git a/src/laion_clap/training/zero_shot.py b/src/laion_clap/training/zero_shot.py
new file mode 100644
index 0000000000000000000000000000000000000000..04472c16e36041f90c8f229c5e026dcc394fb977
--- /dev/null
+++ b/src/laion_clap/training/zero_shot.py
@@ -0,0 +1,90 @@
+# NOTE: This script is currently not supported for CLAP.
+import logging
+from contextlib import suppress
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from clap_module import tokenize
+from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
+
+
+def zero_shot_classifier(model, classnames, templates, args):
+ with torch.no_grad():
+ zeroshot_weights = []
+ for classname in tqdm(classnames):
+ texts = [template(classname) for template in templates] # format with class
+ texts = tokenize(texts).to(args.device) # tokenize
+ if args.distributed and not args.horovod:
+ class_embeddings = model.module.encode_text(texts)
+ else:
+ class_embeddings = model.encode_text(texts)
+ class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
+ class_embedding /= class_embedding.norm()
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
+ return zeroshot_weights
+
+
+def accuracy(output, target, topk=(1,)):
+ pred = output.topk(max(topk), 1, True, True)[1].t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+ return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
+
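+# Worked example (added for clarity): with logits [[0.9, 0.1, 0.0]] and target [2],
+# class 2 only appears at rank 3 in the sorted predictions, so
+# accuracy(logits, target, topk=(1, 3)) returns the correct counts [0.0, 1.0].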
+
+def run(model, classifier, dataloader, args):
+ autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
+ with torch.no_grad():
+ top1, top5, n = 0., 0., 0.
+ for images, target in tqdm(dataloader, unit_scale=args.batch_size):
+ images = images.to(args.device)
+ target = target.to(args.device)
+
+ with autocast():
+ # predict
+ if args.distributed and not args.horovod:
+ image_features = model.module.encode_image(images)
+ else:
+ image_features = model.encode_image(images)
+ image_features = F.normalize(image_features, dim=-1)
+ logits = 100. * image_features @ classifier
+
+ # measure accuracy
+ acc1, acc5 = accuracy(logits, target, topk=(1, 5))
+ top1 += acc1
+ top5 += acc5
+ n += images.size(0)
+
+ top1 = (top1 / n)
+ top5 = (top5 / n)
+ return top1, top5
+
+
+def zero_shot_eval(model, data, epoch, args):
+ if 'imagenet-val' not in data and 'imagenet-v2' not in data:
+ return {}
+ if args.zeroshot_frequency == 0:
+ return {}
+ if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
+ return {}
+
+ logging.info('Starting zero-shot imagenet.')
+
+ logging.info('Building zero-shot classifier')
+ classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)
+
+ logging.info('Using classifier')
+ results = {}
+ if 'imagenet-val' in data:
+ top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args)
+ results['imagenet-zeroshot-val-top1'] = top1
+ results['imagenet-zeroshot-val-top5'] = top5
+ if 'imagenet-v2' in data:
+ top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args)
+ results['imagenetv2-zeroshot-val-top1'] = top1
+ results['imagenetv2-zeroshot-val-top5'] = top5
+
+ logging.info('Finished zero-shot imagenet.')
+
+ return results
diff --git a/src/laion_clap/unit_test.py b/src/laion_clap/unit_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f138c627b43bb70b45bdd9364841ac0f113e32dc
--- /dev/null
+++ b/src/laion_clap/unit_test.py
@@ -0,0 +1,75 @@
+"""
+Contrastive Language-Audio Pretraining Model from LAION
+--------------------------------------------------------
+Paper: https://arxiv.org/abs/2211.06687
+Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui
+Support: LAION
+"""
+
+import numpy as np
+import librosa
+import torch
+import laion_clap
+
+# quantization
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1., a_max=1.)
+ return (x * 32767.).astype(np.int16)
+
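+# Illustrative round-trip check (added sketch, not part of the original script):
+# the int16 quantization above should only introduce errors on the order of 1/32767.
+_probe = np.array([0.0, 0.5, -1.0], dtype=np.float32)
+assert np.allclose(int16_to_float32(float32_to_int16(_probe)), _probe, atol=1e-3)
+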
+model = laion_clap.CLAP_Module(enable_fusion=False)
+model.load_ckpt()
+
+# Directly get audio embeddings from audio files
+audio_file = [
+ '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav',
+ '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav'
+]
+audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=False)
+print(audio_embed[:,-20:])
+print(audio_embed.shape)
+
+# Get audio embeddings from audio data
+audio_data, _ = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) # sample rate should be 48000
+audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T)
+audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False)
+print(audio_embed[:,-20:])
+print(audio_embed.shape)
+
+# Directly get audio embeddings from audio files, but return torch tensor
+audio_file = [
+ '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav',
+ '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav'
+]
+audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=True)
+print(audio_embed[:,-20:])
+print(audio_embed.shape)
+
+# Get audio embeddings from audio data
+audio_data, _ = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) # sample rate should be 48000
+audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T)
+audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # quantize before sending it to the model
+audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=True)
+print(audio_embed[:,-20:])
+print(audio_embed.shape)
+
+# Get text embeddings from texts:
+text_data = ["I love the contrastive learning", "I love the pretrain model"]
+text_embed = model.get_text_embedding(text_data)
+print(text_embed)
+print(text_embed.shape)
+
+# Get text embeddings from texts, but return torch tensor:
+text_data = ["I love the contrastive learning", "I love the pretrain model"]
+text_embed = model.get_text_embedding(text_data, use_tensor=True)
+print(text_embed)
+print(text_embed.shape)
+
+
+
+
+
+
diff --git a/src/tests/__init__.py b/src/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/tests/check_ckpt.py b/src/tests/check_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d034dd44ac67643f5f2b6ed8d229ee20d1ccbde4
--- /dev/null
+++ b/src/tests/check_ckpt.py
@@ -0,0 +1,802 @@
+import torch
+
+def keys_in_state_dict(ckpt, device='cpu'):
+    if device == "cpu":
+ a = torch.load(ckpt, map_location=torch.device('cpu'))["state_dict"]
+ else:
+ a = torch.load(ckpt)["state_dict"]
+ print("keys_in_state_dict", a.keys())
+
+
+def check_ckpt_diff(ckpt_a, ckpt_b, key_include=None, key_exclude=None, device='cpu', verbose=True):
+    if device == "cpu":
+ a = torch.load(ckpt_a, map_location=torch.device('cpu'))["state_dict"]
+ b = torch.load(ckpt_b, map_location=torch.device('cpu'))["state_dict"]
+ else:
+ a = torch.load(ckpt_a)["state_dict"]
+ b = torch.load(ckpt_b)["state_dict"]
+ a_sum = 0
+ b_sum = 0
+ difference_count = 0
+ for k in a.keys():
+ if key_include is not None and key_include not in k:
+ continue
+ if key_exclude is not None and key_exclude in k:
+ continue
+ if k in b.keys():
+ a_sum += torch.sum(a[k])
+ b_sum += torch.sum(b[k])
+ if verbose:
+ if torch.sum(a[k]) != torch.sum(b[k]):
+ print(f"key {k} is different")
+ difference_count += 1
+ print("a_sum: ", a_sum)
+ print("b_sum: ", b_sum)
+ print("diff: ", a_sum - b_sum)
+ if verbose:
+ print("difference_count: ", difference_count)
+ return bool(a_sum - b_sum)
+
+# Transformer no freeze:
+# check_ckpt_diff("/fsx/clap_logs/2022_09_11-19_37_08-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_11-19_37_08-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.resblocks")
+
+check_ckpt_diff("/fsx/clap_logs/2022_09_29-23_42_40-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_1.pt",
+ "/fsx/clap_logs/2022_09_29-23_42_40-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_2.pt",
+ "text_branch.resblocks")
+
+# key module.text_branch.resblocks.0.attn.in_proj_weight is different
+# key module.text_branch.resblocks.0.attn.in_proj_bias is different
+# key module.text_branch.resblocks.0.attn.out_proj.weight is different
+# key module.text_branch.resblocks.0.attn.out_proj.bias is different
+# key module.text_branch.resblocks.0.ln_1.weight is different
+# key module.text_branch.resblocks.0.ln_1.bias is different
+# key module.text_branch.resblocks.0.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.0.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.0.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.0.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.0.ln_2.weight is different
+# key module.text_branch.resblocks.0.ln_2.bias is different
+# key module.text_branch.resblocks.1.attn.in_proj_weight is different
+# key module.text_branch.resblocks.1.attn.in_proj_bias is different
+# key module.text_branch.resblocks.1.attn.out_proj.weight is different
+# key module.text_branch.resblocks.1.attn.out_proj.bias is different
+# key module.text_branch.resblocks.1.ln_1.weight is different
+# key module.text_branch.resblocks.1.ln_1.bias is different
+# key module.text_branch.resblocks.1.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.1.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.1.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.1.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.1.ln_2.weight is different
+# key module.text_branch.resblocks.1.ln_2.bias is different
+# key module.text_branch.resblocks.2.attn.in_proj_weight is different
+# key module.text_branch.resblocks.2.attn.in_proj_bias is different
+# key module.text_branch.resblocks.2.attn.out_proj.weight is different
+# key module.text_branch.resblocks.2.attn.out_proj.bias is different
+# key module.text_branch.resblocks.2.ln_1.weight is different
+# key module.text_branch.resblocks.2.ln_1.bias is different
+# key module.text_branch.resblocks.2.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.2.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.2.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.2.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.2.ln_2.weight is different
+# key module.text_branch.resblocks.2.ln_2.bias is different
+# key module.text_branch.resblocks.3.attn.in_proj_weight is different
+# key module.text_branch.resblocks.3.attn.in_proj_bias is different
+# key module.text_branch.resblocks.3.attn.out_proj.weight is different
+# key module.text_branch.resblocks.3.attn.out_proj.bias is different
+# key module.text_branch.resblocks.3.ln_1.weight is different
+# key module.text_branch.resblocks.3.ln_1.bias is different
+# key module.text_branch.resblocks.3.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.3.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.3.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.3.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.3.ln_2.weight is different
+# key module.text_branch.resblocks.3.ln_2.bias is different
+# key module.text_branch.resblocks.4.attn.in_proj_weight is different
+# key module.text_branch.resblocks.4.attn.in_proj_bias is different
+# key module.text_branch.resblocks.4.attn.out_proj.weight is different
+# key module.text_branch.resblocks.4.attn.out_proj.bias is different
+# key module.text_branch.resblocks.4.ln_1.weight is different
+# key module.text_branch.resblocks.4.ln_1.bias is different
+# key module.text_branch.resblocks.4.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.4.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.4.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.4.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.4.ln_2.weight is different
+# key module.text_branch.resblocks.4.ln_2.bias is different
+# key module.text_branch.resblocks.5.attn.in_proj_weight is different
+# key module.text_branch.resblocks.5.attn.in_proj_bias is different
+# key module.text_branch.resblocks.5.attn.out_proj.weight is different
+# key module.text_branch.resblocks.5.attn.out_proj.bias is different
+# key module.text_branch.resblocks.5.ln_1.weight is different
+# key module.text_branch.resblocks.5.ln_1.bias is different
+# key module.text_branch.resblocks.5.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.5.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.5.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.5.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.5.ln_2.weight is different
+# key module.text_branch.resblocks.5.ln_2.bias is different
+# key module.text_branch.resblocks.6.attn.in_proj_weight is different
+# key module.text_branch.resblocks.6.attn.in_proj_bias is different
+# key module.text_branch.resblocks.6.attn.out_proj.weight is different
+# key module.text_branch.resblocks.6.attn.out_proj.bias is different
+# key module.text_branch.resblocks.6.ln_1.weight is different
+# key module.text_branch.resblocks.6.ln_1.bias is different
+# key module.text_branch.resblocks.6.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.6.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.6.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.6.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.6.ln_2.weight is different
+# key module.text_branch.resblocks.6.ln_2.bias is different
+# key module.text_branch.resblocks.7.attn.in_proj_weight is different
+# key module.text_branch.resblocks.7.attn.in_proj_bias is different
+# key module.text_branch.resblocks.7.attn.out_proj.weight is different
+# key module.text_branch.resblocks.7.attn.out_proj.bias is different
+# key module.text_branch.resblocks.7.ln_1.weight is different
+# key module.text_branch.resblocks.7.ln_1.bias is different
+# key module.text_branch.resblocks.7.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.7.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.7.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.7.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.7.ln_2.weight is different
+# key module.text_branch.resblocks.7.ln_2.bias is different
+# key module.text_branch.resblocks.8.attn.in_proj_weight is different
+# key module.text_branch.resblocks.8.attn.in_proj_bias is different
+# key module.text_branch.resblocks.8.attn.out_proj.weight is different
+# key module.text_branch.resblocks.8.attn.out_proj.bias is different
+# key module.text_branch.resblocks.8.ln_1.weight is different
+# key module.text_branch.resblocks.8.ln_1.bias is different
+# key module.text_branch.resblocks.8.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.8.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.8.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.8.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.8.ln_2.weight is different
+# key module.text_branch.resblocks.8.ln_2.bias is different
+# key module.text_branch.resblocks.9.attn.in_proj_weight is different
+# key module.text_branch.resblocks.9.attn.in_proj_bias is different
+# key module.text_branch.resblocks.9.attn.out_proj.weight is different
+# key module.text_branch.resblocks.9.attn.out_proj.bias is different
+# key module.text_branch.resblocks.9.ln_1.weight is different
+# key module.text_branch.resblocks.9.ln_1.bias is different
+# key module.text_branch.resblocks.9.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.9.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.9.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.9.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.9.ln_2.weight is different
+# key module.text_branch.resblocks.9.ln_2.bias is different
+# key module.text_branch.resblocks.10.attn.in_proj_weight is different
+# key module.text_branch.resblocks.10.attn.in_proj_bias is different
+# key module.text_branch.resblocks.10.attn.out_proj.weight is different
+# key module.text_branch.resblocks.10.attn.out_proj.bias is different
+# key module.text_branch.resblocks.10.ln_1.weight is different
+# key module.text_branch.resblocks.10.ln_1.bias is different
+# key module.text_branch.resblocks.10.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.10.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.10.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.10.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.10.ln_2.weight is different
+# key module.text_branch.resblocks.10.ln_2.bias is different
+# key module.text_branch.resblocks.11.attn.in_proj_weight is different
+# key module.text_branch.resblocks.11.attn.in_proj_bias is different
+# key module.text_branch.resblocks.11.attn.out_proj.weight is different
+# key module.text_branch.resblocks.11.attn.out_proj.bias is different
+# key module.text_branch.resblocks.11.ln_1.weight is different
+# key module.text_branch.resblocks.11.ln_1.bias is different
+# key module.text_branch.resblocks.11.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.11.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.11.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.11.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.11.ln_2.weight is different
+# key module.text_branch.resblocks.11.ln_2.bias is different
+# a_sum: tensor(12113.6445)
+# b_sum: tensor(9883.4424)
+# diff: tensor(2230.2021)
+# True
+
+
+# Transformer freeze:
+# check_ckpt_diff("/fsx/clap_logs/2022_09_16-18_55_10-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_16-18_55_10-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.resblocks")
+
+# key module.text_branch.resblocks.0.attn.in_proj_weight is different
+# key module.text_branch.resblocks.0.attn.in_proj_bias is different
+# key module.text_branch.resblocks.0.attn.out_proj.weight is different
+# key module.text_branch.resblocks.0.attn.out_proj.bias is different
+# key module.text_branch.resblocks.0.ln_1.weight is different
+# key module.text_branch.resblocks.0.ln_1.bias is different
+# key module.text_branch.resblocks.0.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.0.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.0.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.0.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.0.ln_2.weight is different
+# key module.text_branch.resblocks.0.ln_2.bias is different
+# key module.text_branch.resblocks.1.attn.in_proj_weight is different
+# key module.text_branch.resblocks.1.attn.in_proj_bias is different
+# key module.text_branch.resblocks.1.attn.out_proj.weight is different
+# key module.text_branch.resblocks.1.attn.out_proj.bias is different
+# key module.text_branch.resblocks.1.ln_1.weight is different
+# key module.text_branch.resblocks.1.ln_1.bias is different
+# key module.text_branch.resblocks.1.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.1.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.1.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.1.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.1.ln_2.weight is different
+# key module.text_branch.resblocks.1.ln_2.bias is different
+# key module.text_branch.resblocks.2.attn.in_proj_weight is different
+# key module.text_branch.resblocks.2.attn.in_proj_bias is different
+# key module.text_branch.resblocks.2.attn.out_proj.weight is different
+# key module.text_branch.resblocks.2.attn.out_proj.bias is different
+# key module.text_branch.resblocks.2.ln_1.weight is different
+# key module.text_branch.resblocks.2.ln_1.bias is different
+# key module.text_branch.resblocks.2.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.2.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.2.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.2.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.2.ln_2.weight is different
+# key module.text_branch.resblocks.2.ln_2.bias is different
+# key module.text_branch.resblocks.3.attn.in_proj_weight is different
+# key module.text_branch.resblocks.3.attn.in_proj_bias is different
+# key module.text_branch.resblocks.3.attn.out_proj.weight is different
+# key module.text_branch.resblocks.3.attn.out_proj.bias is different
+# key module.text_branch.resblocks.3.ln_1.weight is different
+# key module.text_branch.resblocks.3.ln_1.bias is different
+# key module.text_branch.resblocks.3.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.3.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.3.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.3.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.3.ln_2.weight is different
+# key module.text_branch.resblocks.3.ln_2.bias is different
+# key module.text_branch.resblocks.4.attn.in_proj_weight is different
+# key module.text_branch.resblocks.4.attn.in_proj_bias is different
+# key module.text_branch.resblocks.4.attn.out_proj.weight is different
+# key module.text_branch.resblocks.4.attn.out_proj.bias is different
+# key module.text_branch.resblocks.4.ln_1.weight is different
+# key module.text_branch.resblocks.4.ln_1.bias is different
+# key module.text_branch.resblocks.4.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.4.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.4.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.4.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.4.ln_2.weight is different
+# key module.text_branch.resblocks.4.ln_2.bias is different
+# key module.text_branch.resblocks.5.attn.in_proj_weight is different
+# key module.text_branch.resblocks.5.attn.in_proj_bias is different
+# key module.text_branch.resblocks.5.attn.out_proj.weight is different
+# key module.text_branch.resblocks.5.attn.out_proj.bias is different
+# key module.text_branch.resblocks.5.ln_1.weight is different
+# key module.text_branch.resblocks.5.ln_1.bias is different
+# key module.text_branch.resblocks.5.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.5.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.5.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.5.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.5.ln_2.weight is different
+# key module.text_branch.resblocks.5.ln_2.bias is different
+# key module.text_branch.resblocks.6.attn.in_proj_weight is different
+# key module.text_branch.resblocks.6.attn.in_proj_bias is different
+# key module.text_branch.resblocks.6.attn.out_proj.weight is different
+# key module.text_branch.resblocks.6.attn.out_proj.bias is different
+# key module.text_branch.resblocks.6.ln_1.weight is different
+# key module.text_branch.resblocks.6.ln_1.bias is different
+# key module.text_branch.resblocks.6.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.6.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.6.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.6.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.6.ln_2.weight is different
+# key module.text_branch.resblocks.6.ln_2.bias is different
+# key module.text_branch.resblocks.7.attn.in_proj_weight is different
+# key module.text_branch.resblocks.7.attn.in_proj_bias is different
+# key module.text_branch.resblocks.7.attn.out_proj.weight is different
+# key module.text_branch.resblocks.7.attn.out_proj.bias is different
+# key module.text_branch.resblocks.7.ln_1.weight is different
+# key module.text_branch.resblocks.7.ln_1.bias is different
+# key module.text_branch.resblocks.7.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.7.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.7.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.7.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.7.ln_2.weight is different
+# key module.text_branch.resblocks.7.ln_2.bias is different
+# key module.text_branch.resblocks.8.attn.in_proj_weight is different
+# key module.text_branch.resblocks.8.attn.in_proj_bias is different
+# key module.text_branch.resblocks.8.attn.out_proj.weight is different
+# key module.text_branch.resblocks.8.attn.out_proj.bias is different
+# key module.text_branch.resblocks.8.ln_1.weight is different
+# key module.text_branch.resblocks.8.ln_1.bias is different
+# key module.text_branch.resblocks.8.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.8.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.8.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.8.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.8.ln_2.weight is different
+# key module.text_branch.resblocks.8.ln_2.bias is different
+# key module.text_branch.resblocks.9.attn.in_proj_weight is different
+# key module.text_branch.resblocks.9.attn.in_proj_bias is different
+# key module.text_branch.resblocks.9.attn.out_proj.weight is different
+# key module.text_branch.resblocks.9.attn.out_proj.bias is different
+# key module.text_branch.resblocks.9.ln_1.weight is different
+# key module.text_branch.resblocks.9.ln_1.bias is different
+# key module.text_branch.resblocks.9.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.9.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.9.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.9.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.9.ln_2.weight is different
+# key module.text_branch.resblocks.9.ln_2.bias is different
+# key module.text_branch.resblocks.10.attn.in_proj_weight is different
+# key module.text_branch.resblocks.10.attn.in_proj_bias is different
+# key module.text_branch.resblocks.10.attn.out_proj.weight is different
+# key module.text_branch.resblocks.10.attn.out_proj.bias is different
+# key module.text_branch.resblocks.10.ln_1.weight is different
+# key module.text_branch.resblocks.10.ln_1.bias is different
+# key module.text_branch.resblocks.10.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.10.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.10.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.10.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.10.ln_2.weight is different
+# key module.text_branch.resblocks.10.ln_2.bias is different
+# key module.text_branch.resblocks.11.attn.in_proj_weight is different
+# key module.text_branch.resblocks.11.attn.in_proj_bias is different
+# key module.text_branch.resblocks.11.attn.out_proj.weight is different
+# key module.text_branch.resblocks.11.attn.out_proj.bias is different
+# key module.text_branch.resblocks.11.ln_1.weight is different
+# key module.text_branch.resblocks.11.ln_1.bias is different
+# key module.text_branch.resblocks.11.mlp.c_fc.weight is different
+# key module.text_branch.resblocks.11.mlp.c_fc.bias is different
+# key module.text_branch.resblocks.11.mlp.c_proj.weight is different
+# key module.text_branch.resblocks.11.mlp.c_proj.bias is different
+# key module.text_branch.resblocks.11.ln_2.weight is different
+# key module.text_branch.resblocks.11.ln_2.bias is different
+# a_sum: tensor(12133.6348)
+# b_sum: tensor(10423.9521)
+# diff: tensor(1709.6826)
+# True
+
+
+# bert no freeze:
+# check_ckpt_diff("/fsx/clap_logs/2022_09_14-02_33_11-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_14-02_33_11-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.encoder")
+
+# key module.text_branch.encoder.layer.0.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.0.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.0.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.0.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.0.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.0.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.0.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.0.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.0.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.0.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.0.output.dense.weight is different
+# key module.text_branch.encoder.layer.0.output.dense.bias is different
+# key module.text_branch.encoder.layer.0.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.0.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.1.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.1.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.1.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.1.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.1.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.1.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.1.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.1.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.1.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.1.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.1.output.dense.weight is different
+# key module.text_branch.encoder.layer.1.output.dense.bias is different
+# key module.text_branch.encoder.layer.1.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.1.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.2.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.2.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.2.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.2.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.2.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.2.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.2.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.2.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.2.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.2.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.2.output.dense.weight is different
+# key module.text_branch.encoder.layer.2.output.dense.bias is different
+# key module.text_branch.encoder.layer.2.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.2.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.3.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.3.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.3.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.3.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.3.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.3.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.3.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.3.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.3.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.3.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.3.output.dense.weight is different
+# key module.text_branch.encoder.layer.3.output.dense.bias is different
+# key module.text_branch.encoder.layer.3.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.3.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.4.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.4.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.4.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.4.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.4.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.4.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.4.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.4.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.4.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.4.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.4.output.dense.weight is different
+# key module.text_branch.encoder.layer.4.output.dense.bias is different
+# key module.text_branch.encoder.layer.4.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.4.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.5.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.5.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.5.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.5.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.5.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.5.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.5.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.5.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.5.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.5.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.5.output.dense.weight is different
+# key module.text_branch.encoder.layer.5.output.dense.bias is different
+# key module.text_branch.encoder.layer.5.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.5.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.6.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.6.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.6.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.6.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.6.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.6.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.6.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.6.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.6.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.6.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.6.output.dense.weight is different
+# key module.text_branch.encoder.layer.6.output.dense.bias is different
+# key module.text_branch.encoder.layer.6.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.6.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.7.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.7.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.7.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.7.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.7.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.7.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.7.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.7.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.7.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.7.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.7.output.dense.weight is different
+# key module.text_branch.encoder.layer.7.output.dense.bias is different
+# key module.text_branch.encoder.layer.7.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.7.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.8.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.8.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.8.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.8.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.8.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.8.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.8.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.8.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.8.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.8.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.8.output.dense.weight is different
+# key module.text_branch.encoder.layer.8.output.dense.bias is different
+# key module.text_branch.encoder.layer.8.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.8.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.9.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.9.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.9.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.9.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.9.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.9.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.9.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.9.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.9.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.9.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.9.output.dense.weight is different
+# key module.text_branch.encoder.layer.9.output.dense.bias is different
+# key module.text_branch.encoder.layer.9.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.9.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.10.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.10.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.10.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.10.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.10.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.10.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.10.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.10.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.10.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.10.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.10.output.dense.weight is different
+# key module.text_branch.encoder.layer.10.output.dense.bias is different
+# key module.text_branch.encoder.layer.10.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.10.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.11.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.11.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.11.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.11.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.11.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.11.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.11.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.11.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.11.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.11.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.11.output.dense.weight is different
+# key module.text_branch.encoder.layer.11.output.dense.bias is different
+# key module.text_branch.encoder.layer.11.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.11.output.LayerNorm.bias is different
+# a_sum: tensor(15185.1230)
+# b_sum: tensor(15576.5596)
+# diff: tensor(-391.4365)
+# True
+
+
+# bert freeze:
+# check_ckpt_diff("/fsx/clap_logs/2022_09_13-01_25_15-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_13-01_25_15-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.encoder")
+
+# key module.text_branch.encoder.layer.0.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.0.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.0.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.0.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.0.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.0.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.0.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.0.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.0.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.0.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.0.output.dense.weight is different
+# key module.text_branch.encoder.layer.0.output.dense.bias is different
+# key module.text_branch.encoder.layer.0.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.0.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.1.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.1.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.1.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.1.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.1.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.1.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.1.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.1.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.1.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.1.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.1.output.dense.weight is different
+# key module.text_branch.encoder.layer.1.output.dense.bias is different
+# key module.text_branch.encoder.layer.1.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.1.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.2.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.2.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.2.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.2.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.2.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.2.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.2.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.2.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.2.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.2.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.2.output.dense.weight is different
+# key module.text_branch.encoder.layer.2.output.dense.bias is different
+# key module.text_branch.encoder.layer.2.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.2.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.3.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.3.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.3.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.3.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.3.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.3.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.3.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.3.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.3.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.3.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.3.output.dense.weight is different
+# key module.text_branch.encoder.layer.3.output.dense.bias is different
+# key module.text_branch.encoder.layer.3.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.3.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.4.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.4.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.4.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.4.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.4.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.4.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.4.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.4.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.4.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.4.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.4.output.dense.weight is different
+# key module.text_branch.encoder.layer.4.output.dense.bias is different
+# key module.text_branch.encoder.layer.4.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.4.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.5.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.5.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.5.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.5.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.5.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.5.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.5.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.5.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.5.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.5.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.5.output.dense.weight is different
+# key module.text_branch.encoder.layer.5.output.dense.bias is different
+# key module.text_branch.encoder.layer.5.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.5.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.6.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.6.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.6.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.6.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.6.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.6.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.6.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.6.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.6.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.6.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.6.output.dense.weight is different
+# key module.text_branch.encoder.layer.6.output.dense.bias is different
+# key module.text_branch.encoder.layer.6.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.6.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.7.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.7.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.7.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.7.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.7.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.7.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.7.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.7.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.7.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.7.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.7.output.dense.weight is different
+# key module.text_branch.encoder.layer.7.output.dense.bias is different
+# key module.text_branch.encoder.layer.7.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.7.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.8.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.8.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.8.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.8.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.8.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.8.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.8.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.8.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.8.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.8.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.8.output.dense.weight is different
+# key module.text_branch.encoder.layer.8.output.dense.bias is different
+# key module.text_branch.encoder.layer.8.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.8.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.9.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.9.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.9.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.9.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.9.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.9.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.9.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.9.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.9.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.9.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.9.output.dense.weight is different
+# key module.text_branch.encoder.layer.9.output.dense.bias is different
+# key module.text_branch.encoder.layer.9.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.9.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.10.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.10.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.10.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.10.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.10.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.10.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.10.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.10.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.10.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.10.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.10.output.dense.weight is different
+# key module.text_branch.encoder.layer.10.output.dense.bias is different
+# key module.text_branch.encoder.layer.10.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.10.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.11.attention.self.query.weight is different
+# key module.text_branch.encoder.layer.11.attention.self.query.bias is different
+# key module.text_branch.encoder.layer.11.attention.self.key.weight is different
+# key module.text_branch.encoder.layer.11.attention.self.key.bias is different
+# key module.text_branch.encoder.layer.11.attention.self.value.weight is different
+# key module.text_branch.encoder.layer.11.attention.self.value.bias is different
+# key module.text_branch.encoder.layer.11.attention.output.dense.weight is different
+# key module.text_branch.encoder.layer.11.attention.output.dense.bias is different
+# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.bias is different
+# key module.text_branch.encoder.layer.11.intermediate.dense.weight is different
+# key module.text_branch.encoder.layer.11.intermediate.dense.bias is different
+# key module.text_branch.encoder.layer.11.output.dense.weight is different
+# key module.text_branch.encoder.layer.11.output.dense.bias is different
+# key module.text_branch.encoder.layer.11.output.LayerNorm.weight is different
+# key module.text_branch.encoder.layer.11.output.LayerNorm.bias is different
+# a_sum: tensor(15078.6641)
+# b_sum: tensor(15540.0723)
+# diff: tensor(-461.4082)
+# True
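+# note: in this "bert freeze" run, every text_branch encoder layer still
+# differs between epoch_10 and epoch_100.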
+
+# linear_prob_text
+# check_ckpt_diff("/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_50.pt", "/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_100.pt", "text_branch.resblocks")
+
+# a_sum: tensor(12111.0244)
+# b_sum: tensor(12111.0244)
+# diff: tensor(0.)
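+# a zero diff means text_branch.resblocks is identical across the two
+# linear-probe checkpoints, i.e. the text branch stayed frozen.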
+
+# linear_prob_audio
+# check_ckpt_diff("/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_50.pt", "/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_100.pt", "clap_model")
+
+# key clap_model.audio_branch.bn0.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block1.bn1.running_mean is different
+# key clap_model.audio_branch.conv_block1.bn1.running_var is different
+# key clap_model.audio_branch.conv_block1.bn1.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block1.bn2.running_mean is different
+# key clap_model.audio_branch.conv_block1.bn2.running_var is different
+# key clap_model.audio_branch.conv_block1.bn2.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block2.bn1.running_mean is different
+# key clap_model.audio_branch.conv_block2.bn1.running_var is different
+# key clap_model.audio_branch.conv_block2.bn1.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block2.bn2.running_mean is different
+# key clap_model.audio_branch.conv_block2.bn2.running_var is different
+# key clap_model.audio_branch.conv_block2.bn2.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block3.bn1.running_mean is different
+# key clap_model.audio_branch.conv_block3.bn1.running_var is different
+# key clap_model.audio_branch.conv_block3.bn1.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block3.bn2.running_mean is different
+# key clap_model.audio_branch.conv_block3.bn2.running_var is different
+# key clap_model.audio_branch.conv_block3.bn2.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block4.bn1.running_mean is different
+# key clap_model.audio_branch.conv_block4.bn1.running_var is different
+# key clap_model.audio_branch.conv_block4.bn1.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block4.bn2.running_mean is different
+# key clap_model.audio_branch.conv_block4.bn2.running_var is different
+# key clap_model.audio_branch.conv_block4.bn2.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block5.bn1.running_mean is different
+# key clap_model.audio_branch.conv_block5.bn1.running_var is different
+# key clap_model.audio_branch.conv_block5.bn1.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block5.bn2.running_mean is different
+# key clap_model.audio_branch.conv_block5.bn2.running_var is different
+# key clap_model.audio_branch.conv_block5.bn2.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block6.bn1.running_mean is different
+# key clap_model.audio_branch.conv_block6.bn1.running_var is different
+# key clap_model.audio_branch.conv_block6.bn1.num_batches_tracked is different
+# key clap_model.audio_branch.conv_block6.bn2.running_mean is different
+# key clap_model.audio_branch.conv_block6.bn2.running_var is different
+# key clap_model.audio_branch.conv_block6.bn2.num_batches_tracked is different
+# a_sum: tensor(120061.5078)
+# b_sum: tensor(122656.0469)
+# diff: tensor(-2594.5391)
+# True
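+# only BatchNorm running statistics differ here; no learnable weights changed
+# between the two linear-probe checkpoints.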
+
diff --git a/src/tests/check_tars.py b/src/tests/check_tars.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dcf1c120dcc316f4f3166f57d448e64eaf2dbdd
--- /dev/null
+++ b/src/tests/check_tars.py
@@ -0,0 +1,120 @@
+import webdataset as wds
+import soundfile as sf
+import io
+import os
+import random
+import copy
+from tqdm import tqdm
+import shutil
+import argparse
+import traceback
+import logging
+import json
+from laion_clap import tokenize
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--tar-path",
+ type=str,
+ default=None,
+ help="Path to the tars",
+ )
+ parser.add_argument(
+ "--start",
+ type=int,
+ default=0,
+ help="start from tar-path + start",
+ )
+ parser.add_argument(
+ "--end",
+ type=int,
+ default=99999,
+ help="end with tar-path + end",
+ )
+ parser.add_argument(
+ "--exclude",
+ nargs='+',
+ default=None,
+ help="exclude tar-path + exclude",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=1,
+ )
+ parser.add_argument(
+ "--order",
+ default=False,
+ action='store_true',
+ help="if keep the search order accendingly",
+ )
+ args = parser.parse_args()
+ return args
+
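+# Example invocation (paths are hypothetical):
+#   python check_tars.py --tar-path /path/to/tars --start 0 --end 200 --exclude 5 7 --order
+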
+def log_and_continue(exn):
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+ return True
+
+def preprocess(
+ sample,
+):
+ """
+    Preprocess a single sample for the webdataset dataloader.
+ """
+ audio_ext = "flac"
+ text_ext = "json"
+ audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+ json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
+ sample["waveform"] = audio_data
+ texts = json_dict_raw["text"]
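+    # when the json carries multiple captions, keep one at random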
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
+ texts = random.choice(texts)
+ sample["raw_text"] = texts
+ sample["text"] = tokenize(texts)
+ return sample
+
+if __name__ == "__main__":
+ args = parse_args()
+ tar_path = args.tar_path
+ idx_list = list(range(args.start, args.end))
+    if args.exclude is not None:
+        for x in args.exclude:
+            # argparse leaves the excluded indices as strings, so cast before removing
+            idx_list.remove(int(x))
+ if not args.order:
+ random.shuffle(idx_list)
+ if "aws" in tar_path:
+ args.local = False
+ if args.local:
+ input_shards = [os.path.join(args.tar_path, str(i)+".tar") for i in idx_list]
+ else:
+ input_shards = [os.path.join(args.tar_path, str(i)+".tar -") for i in idx_list]
+ pipeline = [wds.SimpleShardList(input_shards)]
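+    # expand shards into samples, decode audio/text, and keep url/key/waveform for inspection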
+ pipeline.extend(
+ [
+ wds.split_by_node,
+ wds.split_by_worker,
+ wds.tarfile_to_samples(handler=log_and_continue),
+ wds.map(preprocess),
+ wds.to_tuple("__url__", "__key__", "waveform"),
+ wds.batched(1),
+ ]
+ )
+ dataset = wds.DataPipeline(*pipeline)
+ dataloader = wds.WebLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
+ old_k = 0
+ old_batch = None
+ try:
+ for k, batch in tqdm(enumerate(dataloader)):
+ print("k:", k)
+ print("batch:", batch)
+ old_k = k
+ old_batch = copy.deepcopy(batch)
+    except Exception:
+        # log the traceback plus the last successfully loaded batch before exiting
+        with open("check_tar_log.txt", "a") as file:
+            traceback.print_exc(file=file)
+        print("old_k:", old_k)
+        print("old_batch:", old_batch)
diff --git a/src/tests/data_loader_test.py b/src/tests/data_loader_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..03be75e7ce16723053eb3d506b59f91160a7f3c7
--- /dev/null
+++ b/src/tests/data_loader_test.py
@@ -0,0 +1,60 @@
+from laion_clap import create_model
+from laion_clap.training.data import get_data
+from laion_clap.training import parse_args
+import torch
+import os
+from tqdm import tqdm
+from laion_clap.training.distributed import is_master, world_info_from_env
+from laion_clap.utils import dataset_split
+
+
+def run_dataloader():
+    # drain the training dataloader once; `dataloader`, `data`, and `args` are
+    # module-level names defined in the __main__ block below
+    for i, batch in enumerate(tqdm(dataloader, total=data["train"].dataloader.num_samples // args.batch_size)):
+        pass
+
+
+if __name__ == '__main__':
+
+ args = parse_args()
+ # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+ args.amodel = args.amodel.replace("/", "-")
+ device = torch.device('cpu')
+
+ # discover initial world args early so we can log properly
+ args.distributed = False
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+ if args.remotedata and is_master(args):
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ model, model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=True,
+ pretrained_audio=args.pretrained_audio,
+ pretrained_text=args.pretrained_text,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type
+ )
+
+ data = get_data(args, model_cfg)
+
+ dataloader, sampler = data["train"].dataloader, data["train"].sampler
+
+ print('dataset size:', data["train"].dataloader.num_samples)
+ print('batch size:', args.batch_size)
+ print('num batches:', data["train"].dataloader.num_samples // args.batch_size)
+
+ run_dataloader()