{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Speech Enhancement with MVDR Beamforming\n\n**Author**: [Zhaoheng Ni](zni@meta.com)_\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Overview\n\nThis is a tutorial on applying Minimum Variance Distortionless\nResponse (MVDR) beamforming to estimate enhanced speech with\nTorchAudio.\n\nSteps:\n\n- Generate an ideal ratio mask (IRM) by dividing the clean/noise\n magnitude by the mixture magnitude.\n- Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.\n- Estimate enhanced speech using MVDR modules\n (:py:func:`torchaudio.transforms.SoudenMVDR` and\n :py:func:`torchaudio.transforms.RTFMVDR`).\n- Benchmark the two methods\n (:py:func:`torchaudio.functional.rtf_evd` and\n :py:func:`torchaudio.functional.rtf_power`) for computing the\n relative transfer function (RTF) matrix of the reference microphone.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchaudio\nimport torchaudio.functional as F\n\nprint(torch.__version__)\nprint(torchaudio.__version__)\n\n\nimport matplotlib.pyplot as plt\nimport mir_eval\nfrom IPython.display import Audio" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Preparation\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1. Import the packages\n\nFirst, we install and import the necessary packages.\n\n``mir_eval``, ``pesq``, and ``pystoi`` packages are required for\nevaluating the speech enhancement performance.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# When running this example in notebook, install the following packages.\n# !pip3 install mir_eval\n# !pip3 install pesq\n# !pip3 install pystoi\n\nfrom pesq import pesq\nfrom pystoi import stoi\nfrom torchaudio.utils import download_asset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2. Download audio data\n\nThe multi-channel audio example is selected from\n[ConferencingSpeech](https://github.com/ConferencingSpeech/ConferencingSpeech2021)_\ndataset.\n\nThe original filename is\n\n ``SSB07200001\\#noise-sound-bible-0038\\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\\#15217\\#25.16333303751458\\#0.2101221178590021.wav``\n\nwhich was generated with:\n\n- ``SSB07200001.wav`` from\n [AISHELL-3](https://www.openslr.org/93/)_ (Apache License\n v.2.0)\n- ``noise-sound-bible-0038.wav`` from\n [MUSAN](http://www.openslr.org/17/)_ (Attribution 4.0\n International \u2014 CC BY 4.0)\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "SAMPLE_RATE = 16000\nSAMPLE_CLEAN = download_asset(\"tutorial-assets/mvdr/clean_speech.wav\")\nSAMPLE_NOISE = download_asset(\"tutorial-assets/mvdr/noise.wav\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3. 
Helper functions\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_spectrogram(stft, title=\"Spectrogram\"):\n    # Plot the magnitude of the STFT coefficients on a decibel scale.\n    magnitude = stft.abs()\n    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()\n    figure, axis = plt.subplots(1, 1)\n    img = axis.imshow(spectrogram, cmap=\"viridis\", vmin=-100, vmax=0, origin=\"lower\", aspect=\"auto\")\n    axis.set_title(title)\n    plt.colorbar(img, ax=axis)\n\n\ndef plot_mask(mask, title=\"Mask\"):\n    # Plot a time-frequency mask with values in [0, 1].\n    mask = mask.numpy()\n    figure, axis = plt.subplots(1, 1)\n    img = axis.imshow(mask, cmap=\"viridis\", origin=\"lower\", aspect=\"auto\")\n    axis.set_title(title)\n    plt.colorbar(img, ax=axis)\n\n\ndef si_snr(estimate, reference, epsilon=1e-8):\n    # Scale-invariant signal-to-noise ratio (Si-SNR) in dB.\n    estimate = estimate - estimate.mean()\n    reference = reference - reference.mean()\n    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)\n    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)\n    scale = mix_pow / (reference_pow + epsilon)\n\n    # Project the estimate onto the reference to remove the scale ambiguity.\n    reference = scale * reference\n    error = estimate - reference\n\n    reference_pow = reference.pow(2)\n    error_pow = error.pow(2)\n\n    reference_pow = reference_pow.mean(axis=1)\n    error_pow = error_pow.mean(axis=1)\n\n    si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)\n    return si_snr.item()\n\n\ndef generate_mixture(waveform_clean, waveform_noise, target_snr):\n    # Scale the noise (in place) so that the mixture attains the target SNR.\n    power_clean_signal = waveform_clean.pow(2).mean()\n    power_noise_signal = waveform_noise.pow(2).mean()\n    current_snr = 10 * torch.log10(power_clean_signal / power_noise_signal)\n    waveform_noise *= 10 ** (-(target_snr - current_snr) / 20)\n    return waveform_clean + waveform_noise\n\n\ndef evaluate(estimate, reference):\n    # Report SDR, Si-SNR, PESQ, and STOI for a single-channel estimate.\n    si_snr_score = si_snr(estimate, reference)\n    (\n        sdr,\n        _,\n        _,\n        _,\n    ) = mir_eval.separation.bss_eval_sources(reference.numpy(), estimate.numpy(), False)\n    pesq_mix = pesq(SAMPLE_RATE, estimate[0].numpy(), reference[0].numpy(), \"wb\")\n    stoi_mix = stoi(reference[0].numpy(), estimate[0].numpy(), SAMPLE_RATE, extended=False)\n    print(f\"SDR score: {sdr[0]}\")\n    print(f\"Si-SNR score: {si_snr_score}\")\n    print(f\"PESQ score: {pesq_mix}\")\n    print(f\"STOI score: {stoi_mix}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Generate Ideal Ratio Masks (IRMs)\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.1. Load audio data\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)\nwaveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)\nassert sr == sr2 == SAMPLE_RATE\n# The mixture waveform is a combination of clean and noise waveforms with a desired SNR.\ntarget_snr = 3\nwaveform_mix = generate_mixture(waveform_clean, waveform_noise, target_snr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: To improve computational robustness, it is recommended to represent\nthe waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "waveform_mix = waveform_mix.to(torch.double)\nwaveform_clean = waveform_clean.to(torch.double)\nwaveform_noise = waveform_noise.to(torch.double)" ] }, 
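{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick check, we can recompute the SNR of the mixture from its\ncomponents. ``generate_mixture`` scales ``waveform_noise`` in place, so the\ntensor now holds the scaled noise, and the measured value should match\n``target_snr``. This is a minimal sketch for verification only.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Recompute the mixture SNR from its components. Since generate_mixture\n# scales waveform_noise in place, the result should match target_snr (3 dB).\npower_clean = waveform_clean.pow(2).mean()\npower_noise = waveform_noise.pow(2).mean()\nsnr = 10 * torch.log10(power_clean / power_noise)\nprint(f\"Mixture SNR: {snr.item():.2f} dB\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.2. 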
Compute STFT coefficients\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N_FFT = 1024\nN_HOP = 256\n# power=None keeps the complex-valued STFT coefficients instead of the power spectrogram.\nstft = torchaudio.transforms.Spectrogram(\n    n_fft=N_FFT,\n    hop_length=N_HOP,\n    power=None,\n)\nistft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)\n\nstft_mix = stft(waveform_mix)\nstft_clean = stft(waveform_clean)\nstft_noise = stft(waveform_noise)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2.1. Visualize mixture speech\n\nWe evaluate the quality of the mixture speech or the enhanced speech\nusing the following three metrics:\n\n- signal-to-distortion ratio (SDR)\n- scale-invariant signal-to-noise ratio (Si-SNR, or Si-SDR in some papers)\n- Perceptual Evaluation of Speech Quality (PESQ)\n\nWe also evaluate the intelligibility of the speech with the Short-Time Objective Intelligibility\n(STOI) metric.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_spectrogram(stft_mix[0], \"Spectrogram of Mixture Speech (dB)\")\nevaluate(waveform_mix[0:1], waveform_clean[0:1])\nAudio(waveform_mix[0], rate=SAMPLE_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2.2. Visualize clean speech\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_spectrogram(stft_clean[0], \"Spectrogram of Clean Speech (dB)\")\nAudio(waveform_clean[0], rate=SAMPLE_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.2.3. Visualize noise\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_spectrogram(stft_noise[0], \"Spectrogram of Noise (dB)\")\nAudio(waveform_noise[0], rate=SAMPLE_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.3. Define the reference microphone\n\nWe choose the first microphone in the array as the reference channel for demonstration.\nThe selection of the reference channel may depend on the design of the microphone array.\n\nYou can also apply an end-to-end neural network that estimates both the reference channel and\nthe PSD matrices, and then obtains the enhanced STFT coefficients with the MVDR module.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "REFERENCE_CHANNEL = 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3.4. Compute IRMs\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_irms(stft_clean, stft_noise):\n    # Power spectrograms of the clean speech and the noise\n    power_clean = stft_clean.abs() ** 2\n    power_noise = stft_noise.abs() ** 2\n    irm_speech = power_clean / (power_clean + power_noise)\n    irm_noise = power_noise / (power_clean + power_noise)\n    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]\n\n\nirm_speech, irm_noise = get_irms(stft_clean, stft_noise)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.4.1. Visualize IRM of target speech\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_mask(irm_speech, \"IRM of the Target Speech\")" ] }, 
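{ "cell_type": "markdown", "metadata": {}, "source": [ "By construction, the speech and noise masks are complementary: they sum to\none at every time-frequency bin. A minimal sketch of this check (assuming no\nbin is exactly silent in both sources) is shown below.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# The speech and noise IRMs sum to one at every time-frequency bin\n# (up to floating-point error).\nassert torch.allclose(irm_speech + irm_noise, torch.ones_like(irm_speech))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.4.2. 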
Visualize IRM of noise\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_mask(irm_noise, \"IRM of the Noise\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Compute PSD matrices\n\n:py:class:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given\nthe multi-channel complex-valued STFT coefficients of the mixture speech\nand the time-frequency mask.\n\nThe shape of the PSD matrix is `(..., freq, channel, channel)`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "psd_transform = torchaudio.transforms.PSD()\n\npsd_speech = psd_transform(stft_mix, irm_speech)\npsd_noise = psd_transform(stft_mix, irm_noise)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Beamforming using SoudenMVDR\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.1. Apply beamforming\n\n:py:class:`torchaudio.transforms.SoudenMVDR` takes the multi-channel\ncomplex-valued STFT coefficients of the mixture speech, the PSD matrices of the\ntarget speech and noise, and the reference channel as inputs.\n\nThe output is the single-channel complex-valued STFT coefficients of the enhanced speech.\nWe can then obtain the enhanced waveform by passing this output to the\n:py:class:`torchaudio.transforms.InverseSpectrogram` module.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "mvdr_transform = torchaudio.transforms.SoudenMVDR()\nstft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)\nwaveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.2. Result for SoudenMVDR\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_spectrogram(stft_souden, \"Enhanced Spectrogram by SoudenMVDR (dB)\")\nwaveform_souden = waveform_souden.reshape(1, -1)\nevaluate(waveform_souden, waveform_clean[0:1])\nAudio(waveform_souden, rate=SAMPLE_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Beamforming using RTFMVDR\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.1. Compute RTF\n\nTorchAudio offers two methods for computing the RTF matrix of the\ntarget speech:\n\n- :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue\n decomposition to the PSD matrix of target speech to get the RTF matrix.\n\n- :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration\n method. You can specify the number of iterations with the argument ``n_iter``.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rtf_evd = F.rtf_evd(psd_speech)\nrtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)" ] }, 
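{ "cell_type": "markdown", "metadata": {}, "source": [ "Both functions return a complex-valued RTF with shape `(..., freq, channel)`.\nSince an RTF is only defined up to a complex scale factor, one way to compare\nthe two estimates is the scale-invariant similarity ``|<a, b>| / (||a|| ||b||)``\nper frequency bin; values close to 1 indicate that the two methods agree.\nThe following is a minimal illustrative sketch of such a comparison.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compare the two RTF estimates with a scale-invariant similarity,\n# |<a, b>| / (||a|| * ||b||), computed independently for each frequency bin.\nprint(rtf_evd.shape, rtf_power.shape)\ninner = (rtf_evd.conj() * rtf_power).sum(dim=-1).abs()\nnorms = torch.linalg.vector_norm(rtf_evd, dim=-1) * torch.linalg.vector_norm(rtf_power, dim=-1)\nsimilarity = inner / norms\nprint(f\"Mean similarity across frequency bins: {similarity.mean().item():.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.2. 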
Apply beamforming\n\n:py:class:`torchaudio.transforms.RTFMVDR` takes the multi-channel\ncomplex-valued STFT coefficients of the mixture speech, the RTF matrix of the target speech,\nthe PSD matrix of the noise, and the reference channel as inputs.\n\nThe output is the single-channel complex-valued STFT coefficients of the enhanced speech.\nWe can then obtain the enhanced waveform by passing this output to the\n:py:class:`torchaudio.transforms.InverseSpectrogram` module.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "mvdr_transform = torchaudio.transforms.RTFMVDR()\n\n# compute the enhanced speech based on F.rtf_evd\nstft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)\nwaveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])\n\n# compute the enhanced speech based on F.rtf_power\nstft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)\nwaveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.3. Result for RTFMVDR with `rtf_evd`\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_spectrogram(stft_rtf_evd, \"Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)\")\nwaveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)\nevaluate(waveform_rtf_evd, waveform_clean[0:1])\nAudio(waveform_rtf_evd, rate=SAMPLE_RATE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6.4. Result for RTFMVDR with `rtf_power`\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_spectrogram(stft_rtf_power, \"Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)\")\nwaveform_rtf_power = waveform_rtf_power.reshape(1, -1)\nevaluate(waveform_rtf_power, waveform_clean[0:1])\nAudio(waveform_rtf_power, rate=SAMPLE_RATE)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }