MusiLingo
Collection
This collection contains the checkpoints and datasets for MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response.
•
5 items
•
Updated
•
3
The model consists of a music encoder, MERT-v1-330M,
a natural language decoder, vicuna-7b-delta-v0,
and a linear projection layer between the two.
This MusiLingo checkpoint is trained on the MusicQA dataset and can answer instructions about raw music audio, such as questions about tempo, emotion, genre, tags, or subjective impressions. You can use the MusicQA dataset with the demo below. For the implementation of MusicQA, please refer to our GitHub repo.
import random

import torch
import torchaudio
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModel
from transformers import StoppingCriteria, StoppingCriteriaList
from transformers import Wav2Vec2FeatureExtractor
def load_audio(
    file_path,
    target_sr,
    is_mono=True,
    is_normalize=False,
    crop_to_length_in_sec=None,
    crop_to_length_in_sample_points=None,
    crop_randomly=False,
    pad=False,
    return_start=False,
    device=torch.device('cpu'),
):
    """Load an audio file and convert it to the target sample rate.

    Supports down-mixing to mono, peak normalization, cropping and padding.
    Cropping/padding is delegated to ``crop_audio`` and happens at the file's
    native sample rate, i.e. *before* resampling.

    Args:
        file_path (str): path to audio file.
        target_sr (int): target sample rate; if it differs from the file's
            sample rate, the waveform is resampled to ``target_sr``.
        is_mono (bool, optional): average multi-channel audio down to a single
            channel. Defaults to True.
        is_normalize (bool, optional): divide by the peak absolute amplitude so
            the signal lies in [-1, 1]. Silent (all-zero) audio is left
            unchanged instead of producing NaNs. Defaults to False.
        crop_to_length_in_sec (float, optional): crop to this length in
            seconds. Defaults to None.
        crop_to_length_in_sample_points (int, optional): crop to this length in
            sample points, counted at the file's native sample rate (before
            resampling). Defaults to None.
        crop_randomly (bool, optional): pick the crop start at random instead
            of at 0. Defaults to False.
        pad (bool, optional): zero-pad at the end up to the crop length if the
            waveform is shorter than it. Defaults to False.
        return_start (bool, optional): also return the crop start index, in
            samples at the file's native sample rate. Defaults to False.
        device (torch.device, optional): device used for resampling. The
            waveform is only moved to this device when resampling actually
            happens. Defaults to torch.device('cpu').

    Returns:
        torch.Tensor: channels-first waveform of shape (n_channels, n_sample);
        (1, n_sample) when ``is_mono`` is True.
        If ``return_start`` is True, the tuple ``(waveform, start)`` instead.
    """
    # TODO: deal with target_depth
    try:
        waveform, sample_rate = torchaudio.load(file_path)
    except Exception:
        # Fall back to the soundfile backend when the default backend cannot
        # read the file. NOTE(review): this backend module is deprecated in
        # newer torchaudio releases -- confirm it exists in the pinned version.
        waveform, sample_rate = torchaudio.backend.soundfile_backend.load(file_path)

    # Channels are on dim 0; average them into one channel, keeping the dim.
    if is_mono and waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if is_normalize:
        peak = waveform.abs().max()
        # Guard against silent audio: 0 / 0 would turn every sample into NaN.
        if peak > 0:
            waveform = waveform / peak

    waveform, start = crop_audio(
        waveform,
        sample_rate,
        crop_to_length_in_sec=crop_to_length_in_sec,
        crop_to_length_in_sample_points=crop_to_length_in_sample_points,
        crop_randomly=crop_randomly,
        pad=pad,
    )

    if sample_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sample_rate, target_sr).to(device)
        waveform = resampler(waveform.to(device))

    if return_start:
        return waveform, start
    return waveform
def crop_audio(
    waveform,
    sample_rate,
    crop_to_length_in_sec=None,
    crop_to_length_in_sample_points=None,
    crop_randomly=False,
    pad=False,
):
    """Trim (or optionally zero-pad) a waveform along its last axis.

    The target length is given either in seconds (converted with
    ``sample_rate``) or directly in sample points -- never both. A falsy
    target length (None, 0, or a seconds value that rounds down to 0 samples)
    leaves the waveform untouched.

    Args:
        waveform (torch.Tensor): audio tensor; time is the last dimension,
            e.g. shape (1, n_sample).
        sample_rate (int): sample rate of ``waveform``, used to convert
            ``crop_to_length_in_sec`` into samples.
        crop_to_length_in_sec (float, optional): target length in seconds.
            Defaults to None.
        crop_to_length_in_sample_points (int, optional): target length in
            samples. Defaults to None.
        crop_randomly (bool, optional): when the waveform is longer than the
            target, choose the window start uniformly at random instead of 0.
            Defaults to False.
        pad (bool, optional): when the waveform is shorter than the target,
            append zeros up to the target length. Defaults to False.

    Returns:
        torch.Tensor: the cropped / padded / unchanged waveform.
        int: index in the original waveform where the returned window starts.
    """
    assert crop_to_length_in_sec is None or crop_to_length_in_sample_points is None, \
        "Only one of crop_to_length_in_sec and crop_to_length_in_sample_points can be specified"

    # Resolve the requested length in samples (at most one source is set).
    if crop_to_length_in_sec:
        target_len = int(sample_rate * crop_to_length_in_sec)
    elif crop_to_length_in_sample_points:
        target_len = crop_to_length_in_sample_points
    else:
        target_len = None

    offset = 0
    if not target_len:
        return waveform, offset

    current_len = waveform.shape[-1]
    if current_len > target_len:
        if crop_randomly:
            offset = random.randint(0, current_len - target_len)
        return waveform[..., offset:offset + target_len], offset

    if current_len < target_len and pad:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - current_len))
    return waveform, offset
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence in the batch ends with a stop sequence.

    Args:
        stops (list, optional): stop sequences, each a 1-D tensor (or list) of
            token ids. Generation stops when the tail of ``input_ids[0]``
            equals any of them. Defaults to no stop sequences.
        encounters (int, optional): accepted for backward compatibility;
            currently unused.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Use None as the default to avoid sharing one mutable list across instances.
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` currently ends with any stop sequence.

        ``**kwargs`` is accepted (and ignored) so extra arguments forwarded by
        the caller do not raise.
        """
        # Only the first sequence of the batch is inspected.
        generated = input_ids[0]
        for stop in self.stops:
            stop_ids = torch.as_tensor(stop, device=generated.device)
            n = stop_ids.shape[-1]
            # An empty stop sequence would match trivially, and a stop longer
            # than the generated text would make the slice shorter and the
            # comparison broadcast into a spurious match -- skip both.
            if n == 0 or n > generated.shape[-1]:
                continue
            if torch.all(stop_ids == generated[-n:]).item():
                return True
        return False
def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
                       max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
                       repetition_penalty=1.0, processor=None):
    """Answer a text instruction about an audio file with a MusiLingo model.

    Steps:
    1. Load the audio, resample it to 24 kHz, and turn it into MERT input
       features.
    2. Encode the features with ``model.encode_audio`` and wrap them into the
       model's instruction prompt.
    3. Prepend a BOS embedding and sample an answer from the LLaMA decoder.

    Args:
        model: the MusiLingo module (``musilingo.model`` in the demo). It must
            expose ``encode_audio``, ``instruction_prompt_wrap``,
            ``prompt_template``, ``llama_tokenizer`` and ``llama_model``.
            The audio is moved to CUDA, so the model is assumed to be on CUDA.
        text (str): instruction / question, appended after the audio
            placeholder.
        audio_path (str): path to the audio file.
        stopping (StoppingCriteriaList): stopping criteria passed to
            ``generate``.
        length_penalty, temperature, max_new_tokens, num_beams, min_length,
        top_p, repetition_penalty: forwarded unchanged to
            ``llama_model.generate``. Sampling (``do_sample=True``) is always
            on. NOTE(review): per the transformers docs, ``length_penalty``
            only affects beam-based generation.
        processor (Wav2Vec2FeatureExtractor, optional): pre-loaded MERT
            feature extractor. When None (default) it is loaded from
            "m-a-p/MERT-v1-330M" on every call; pass one in to avoid reloading.

    Returns:
        str: the decoded answer, cut at the first '###', with everything up to
        the last 'Assistant:' removed and surrounding whitespace stripped.
    """
    # NOTE(review): load_audio crops *before* resampling, so this length is
    # counted at the file's native sample rate: 30*16000+1 samples is 30 s
    # only for 16 kHz input (20 s for 24 kHz input) -- confirm intent.
    # crop_randomly=True means long files yield a different segment (and
    # possibly a different answer) on each call.
    audio = load_audio(audio_path, target_sr=24000,
                       is_mono=True,
                       is_normalize=False,
                       crop_to_length_in_sample_points=int(30*16000)+1,
                       crop_randomly=True,
                       pad=False).cuda()
    if processor is None:
        processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
    audio = processor(audio,
                      sampling_rate=24000,
                      return_tensors="pt")['input_values'][0].cuda()

    # Inference only: do not build autograd graphs for the audio encoder,
    # prompt wrapping, or token embeddings.
    with torch.no_grad():
        audio_embeds, atts_audio = model.encode_audio(audio)
        prompt = '<Audio><AudioHere></Audio> ' + text
        instruction_prompt = [model.prompt_template.format(prompt)]
        audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)

        model.llama_tokenizer.padding_side = "right"
        batch_size = audio_embeds.shape[0]
        # BOS token ids, created on the same device as the audio embeddings
        # they are concatenated with.
        bos = torch.ones([batch_size, 1],
                         dtype=torch.long,
                         device=audio_embeds.device) * model.llama_tokenizer.bos_token_id
        bos_embeds = model.llama_model.model.embed_tokens(bos)
        inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
        # NOTE(review): atts_audio is not used below -- no attention_mask is
        # passed to generate.
        outputs = model.llama_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )

    output_token = outputs[0]
    # The model might output an unknown token <unk> (id 0) at the beginning; remove it.
    if len(output_token) > 0 and output_token[0] == 0:
        output_token = output_token[1:]
    # If there is a start token <s> (id 1) at the beginning, remove it.
    if len(output_token) > 0 and output_token[0] == 1:
        output_token = output_token[1:]
    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###' and anything after it
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text
# Demo: load the MusicQA checkpoint on the GPU and ask one question about one audio file.
musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1", trust_remote_code=True)
musilingo.to("cuda")
musilingo.eval()
prompt = "this is the task instruction and input question for MusiLingo model"
# Path to the input audio file (the demo expects 24 kHz audio).
audio_path = "/path/to/the/24kHz-audio"
# Stop generation on either token-id sequence below. NOTE(review): these are
# presumably the LLaMA token ids for the '###' stop sign that
# get_musilingo_pred strips from the output -- verify with the tokenizer.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                                      torch.tensor([2277, 29937]).cuda()])])
response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
print(response)
If you find this work useful for your research, please consider citing it using the following BibTeX entry:
@inproceedings{deng2024musilingo,
title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
year={2024},
organization={Association for Computational Linguistics}
}