Skip to content

Wav2Rec 🎸

Overview

Wav2Rec is a library for music recommendation based on recent advances in self-supervised neural networks.

Installation

pip install git+git://github.com/TariqAHassan/wav2rec@main

Requires Python 3.7+

How it Works

Wav2Rec is built on top of recently developed techniques for self-supervised learning, whereby rich representations can be learned from data without explict labels. In particular, Wav2Rec leverages the simple siamese (or SimSam) neural network architecture proposed by Chen and He (2020), which is trained with the objective of maximizing the similarity between two augmentations of the same image.

In order to adapt SimSam to work with audio, Wav2Rec introduces two modifications. First, raw audio waveforms are converted into (mel)spectrograms, which can be seen as a form of image. This adaption allows the use of a standard image model encoders, such as ResNet50 or Vision Transformer (see audionets.py). Second, while spectrograms can been seen as form of image, in actuality their statistical properties are quite different from those found in natural images. For instance, because spectrograms have a temporal structure, flipping along this temporal dimension is not a coherent augmentation to perform. Thus, only augmentations which respect the unique statistical properties of spectrograms have been used (see transforms.py).

Once trained, music recommendation is simply a matter of performing nearest neighbour search on the projections obtained from the model.

Quick Start

Training

The Wav2RecNet model, which underlies Wav2Rec() (below), can be trained using any audio dataset. For an example of training the model using the FMA dataset see experiments/fma/train.ipynb.

Inference

The Wav2Rec() class along with a Wav2RecDataset() dataset can be used to generate recommendations of similar music.

from pathlib import Path
from wav2rec import Wav2Rec, Wav2RecDataset

MUSIC_PATH = Path("music")  # directory of music
MODEL_PATH = Path("checkpoints/my_trained_model.ckpt")  # trained model

my_dataset = Wav2RecDataset(MUSIC_PATH, ext="mp3").scan()

model = Wav2Rec(MODEL_PATH)
model.fit(my_dataset)

Once fit, we can load a piece of sample piece of audio

waveform = my_dataset.load_audio("my_song.mp3")

and get some recommendations for similar music.

metrics, paths = model.recommend(waveform, n=3)

Above, metrics is a 2D array which stores the similarity metrics (cosine similarity by default) between waveform and each recommendation. The paths object is also a 2D array, but it contains the paths to the recommended music files.

Note: To get an intuition for the representations that will underlie these recommendations, check out experiments/fma/inference.ipynb.