{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Torchaudio-Squim: Non-intrusive Speech Assessment in TorchAudio\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Author: [Anurag Kumar](anuragkr90@meta.com)_, [Zhaoheng\nNi](zni@meta.com)_\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Overview\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This tutorial shows uses of Torchaudio-Squim to estimate objective and\nsubjective metrics for assessment of speech quality and intelligibility.\n\nTorchAudio-Squim enables speech assessment in Torchaudio. It provides\ninterface and pre-trained models to estimate various speech quality and\nintelligibility metrics. Currently, Torchaudio-Squim [1] supports\nreference-free estimation 3 widely used objective metrics:\n\n- Wideband Perceptual Estimation of Speech Quality (PESQ) [2]\n\n- Short-Time Objective Intelligibility (STOI) [3]\n\n- Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) [4]\n\nIt also supports estimation of subjective Mean Opinion Score (MOS) for a\ngiven audio waveform using Non-Matching References [1, 5].\n\n**References**\n\n[1] Kumar, Anurag, et al.\u00a0\u201cTorchAudio-Squim: Reference-less Speech\nQuality and Intelligibility measures in TorchAudio.\u201d ICASSP 2023-2023\nIEEE International Conference on Acoustics, Speech and Signal Processing\n(ICASSP). IEEE, 2023.\n\n[2] I. Rec, \u201cP.862.2: Wideband extension to recommendation P.862 for the\nassessment of wideband telephone networks and speech codecs,\u201d\nInternational Telecommunication Union, CH\u2013Geneva, 2005.\n\n[3] Taal, C. H., Hendriks, R. C., Heusdens, R., & Jensen, J. (2010,\nMarch). A short-time objective intelligibility measure for\ntime-frequency weighted noisy speech. In 2010 IEEE international\nconference on acoustics, speech and signal processing (pp.\u00a04214-4217).\nIEEE.\n\n[4] Le Roux, Jonathan, et al.\u00a0\u201cSDR\u2013half-baked or well done?.\u201d ICASSP\n2019-2019 IEEE International Conference on Acoustics, Speech and Signal\nProcessing (ICASSP). IEEE, 2019.\n\n[5] Manocha, Pranay, and Anurag Kumar. \u201cSpeech quality assessment\nthrough MOS using non-matching references.\u201d Interspeech, 2022.\n\n\n"
]
},
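{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, writing the estimate as $\\hat{s}$ and the reference as $s$, SI-SDR follows the scale-invariant definition from [4] (the hand-written ``si_snr`` helper defined in Section 2 implements essentially this computation):\n\n$$\\text{SI-SDR}(\\hat{s}, s) = 10 \\log_{10} \\frac{\\lVert \\alpha s \\rVert^2}{\\lVert \\hat{s} - \\alpha s \\rVert^2}, \\qquad \\alpha = \\frac{\\langle \\hat{s}, s \\rangle}{\\lVert s \\rVert^2}$$\n\n\n"
]
},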
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\nimport torchaudio\n\nprint(torch.__version__)\nprint(torchaudio.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Preparation\n\nFirst import the modules and define the helper functions.\n\nWe will need torch, torchaudio to use Torchaudio-squim, Matplotlib to\nplot data, pystoi, pesq for computing reference metrics.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"try:\n from pesq import pesq\n from pystoi import stoi\n from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE\nexcept ImportError:\n try:\n import google.colab # noqa: F401\n\n print(\n \"\"\"\n To enable running this notebook in Google Colab, install nightly\n torch and torchaudio builds by adding the following code block to the top\n of the notebook before running it:\n !pip3 uninstall -y torch torchvision torchaudio\n !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu\n !pip3 install pesq\n !pip3 install pystoi\n \"\"\"\n )\n except Exception:\n pass\n raise\n\n\nimport matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torchaudio.functional as F\nfrom IPython.display import Audio\nfrom torchaudio.utils import download_asset\n\n\ndef si_snr(estimate, reference, epsilon=1e-8):\n estimate = estimate - estimate.mean()\n reference = reference - reference.mean()\n reference_pow = reference.pow(2).mean(axis=1, keepdim=True)\n mix_pow = (estimate * reference).mean(axis=1, keepdim=True)\n scale = mix_pow / (reference_pow + epsilon)\n\n reference = scale * reference\n error = estimate - reference\n\n reference_pow = reference.pow(2)\n error_pow = error.pow(2)\n\n reference_pow = reference_pow.mean(axis=1)\n error_pow = error_pow.mean(axis=1)\n\n si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)\n return si_snr.item()\n\n\ndef plot(waveform, title, sample_rate=16000):\n wav_numpy = waveform.numpy()\n\n sample_size = waveform.shape[1]\n time_axis = torch.arange(0, sample_size) / sample_rate\n\n figure, axes = plt.subplots(2, 1)\n axes[0].plot(time_axis, wav_numpy[0], linewidth=1)\n axes[0].grid(True)\n axes[1].specgram(wav_numpy[0], Fs=sample_rate)\n figure.suptitle(title)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load Speech and Noise Sample\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"SAMPLE_SPEECH = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\")\nSAMPLE_NOISE = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"WAVEFORM_SPEECH, SAMPLE_RATE_SPEECH = torchaudio.load(SAMPLE_SPEECH)\nWAVEFORM_NOISE, SAMPLE_RATE_NOISE = torchaudio.load(SAMPLE_NOISE)\nWAVEFORM_NOISE = WAVEFORM_NOISE[0:1, :]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently, Torchaudio-Squim model only supports 16000 Hz sampling rate.\nResample the waveforms if necessary.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"if SAMPLE_RATE_SPEECH != 16000:\n WAVEFORM_SPEECH = F.resample(WAVEFORM_SPEECH, SAMPLE_RATE_SPEECH, 16000)\n\nif SAMPLE_RATE_NOISE != 16000:\n WAVEFORM_NOISE = F.resample(WAVEFORM_NOISE, SAMPLE_RATE_NOISE, 16000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trim waveforms so that they have the same number of frames.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"if WAVEFORM_SPEECH.shape[1] < WAVEFORM_NOISE.shape[1]:\n WAVEFORM_NOISE = WAVEFORM_NOISE[:, : WAVEFORM_SPEECH.shape[1]]\nelse:\n WAVEFORM_SPEECH = WAVEFORM_SPEECH[:, : WAVEFORM_NOISE.shape[1]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Play speech sample\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"Audio(WAVEFORM_SPEECH.numpy()[0], rate=16000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Play noise sample\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"Audio(WAVEFORM_NOISE.numpy()[0], rate=16000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Create distorted (noisy) speech samples\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"snr_dbs = torch.tensor([20, -5])\nWAVEFORM_DISTORTED = F.add_noise(WAVEFORM_SPEECH, WAVEFORM_NOISE, snr_dbs)"
]
},
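{
"cell_type": "markdown",
"metadata": {},
"source": [
"The mixing above scales the noise so that the ratio of speech power to noise power matches each requested SNR. The cell below is a minimal illustrative sketch of that scaling (it is not the exact implementation of ``torchaudio.functional.add_noise``); ``snr_db`` and ``manual_distorted`` are names introduced here for illustration only.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative sketch: scale the noise so that 10 * log10(P_speech / P_noise)\n# equals the target SNR, then add it to the clean speech.\nsnr_db = 20.0\nspeech_power = WAVEFORM_SPEECH.pow(2).sum(dim=-1, keepdim=True)\nnoise_power = WAVEFORM_NOISE.pow(2).sum(dim=-1, keepdim=True)\nscale = torch.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10)))\nmanual_distorted = WAVEFORM_SPEECH + scale * WAVEFORM_NOISE  # approx. 20dB SNR mixture"
]
},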
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Play distorted speech with 20dB SNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"Audio(WAVEFORM_DISTORTED.numpy()[0], rate=16000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Play distorted speech with -5dB SNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Visualize the waveforms\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize speech sample\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plot(WAVEFORM_SPEECH, \"Clean Speech\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize noise sample\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plot(WAVEFORM_NOISE, \"Noise\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize distorted speech with 20dB SNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plot(WAVEFORM_DISTORTED[0:1], f\"Distorted Speech with {snr_dbs[0]}dB SNR\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Visualize distorted speech with -5dB SNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plot(WAVEFORM_DISTORTED[1:2], f\"Distorted Speech with {snr_dbs[1]}dB SNR\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Predict Objective Metrics\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the pre-trained ``SquimObjective``\\ model.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"objective_model = SQUIM_OBJECTIVE.get_model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compare model outputs with ground truths for distorted speech with 20dB\nSNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"stoi_hyp, pesq_hyp, si_sdr_hyp = objective_model(WAVEFORM_DISTORTED[0:1, :])\nprint(f\"Estimated metrics for distorted speech at {snr_dbs[0]}dB are\\n\")\nprint(f\"STOI: {stoi_hyp[0]}\")\nprint(f\"PESQ: {pesq_hyp[0]}\")\nprint(f\"SI-SDR: {si_sdr_hyp[0]}\\n\")\n\npesq_ref = pesq(16000, WAVEFORM_SPEECH[0].numpy(), WAVEFORM_DISTORTED[0].numpy(), mode=\"wb\")\nstoi_ref = stoi(WAVEFORM_SPEECH[0].numpy(), WAVEFORM_DISTORTED[0].numpy(), 16000, extended=False)\nsi_sdr_ref = si_snr(WAVEFORM_DISTORTED[0:1], WAVEFORM_SPEECH)\nprint(f\"Reference metrics for distorted speech at {snr_dbs[0]}dB are\\n\")\nprint(f\"STOI: {stoi_ref}\")\nprint(f\"PESQ: {pesq_ref}\")\nprint(f\"SI-SDR: {si_sdr_ref}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compare model outputs with ground truths for distorted speech with -5dB\nSNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"stoi_hyp, pesq_hyp, si_sdr_hyp = objective_model(WAVEFORM_DISTORTED[1:2, :])\nprint(f\"Estimated metrics for distorted speech at {snr_dbs[1]}dB are\\n\")\nprint(f\"STOI: {stoi_hyp[0]}\")\nprint(f\"PESQ: {pesq_hyp[0]}\")\nprint(f\"SI-SDR: {si_sdr_hyp[0]}\\n\")\n\npesq_ref = pesq(16000, WAVEFORM_SPEECH[0].numpy(), WAVEFORM_DISTORTED[1].numpy(), mode=\"wb\")\nstoi_ref = stoi(WAVEFORM_SPEECH[0].numpy(), WAVEFORM_DISTORTED[1].numpy(), 16000, extended=False)\nsi_sdr_ref = si_snr(WAVEFORM_DISTORTED[1:2], WAVEFORM_SPEECH)\nprint(f\"Reference metrics for distorted speech at {snr_dbs[1]}dB are\\n\")\nprint(f\"STOI: {stoi_ref}\")\nprint(f\"PESQ: {pesq_ref}\")\nprint(f\"SI-SDR: {si_sdr_ref}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Predict Mean Opinion Scores (Subjective) Metric\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the pre-trained ``SquimSubjective`` model.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"subjective_model = SQUIM_SUBJECTIVE.get_model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load a non-matching reference (NMR)\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"NMR_SPEECH = download_asset(\"tutorial-assets/ctc-decoding/1688-142285-0007.wav\")\n\nWAVEFORM_NMR, SAMPLE_RATE_NMR = torchaudio.load(NMR_SPEECH)\nif SAMPLE_RATE_NMR != 16000:\n WAVEFORM_NMR = F.resample(WAVEFORM_NMR, SAMPLE_RATE_NMR, 16000)"
]
},
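{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, play the non-matching reference sample the same way as the earlier samples\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"Audio(WAVEFORM_NMR.numpy()[0], rate=16000)"
]
},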
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute MOS metric for distorted speech with 20dB SNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"mos = subjective_model(WAVEFORM_DISTORTED[0:1, :], WAVEFORM_NMR)\nprint(f\"Estimated MOS for distorted speech at {snr_dbs[0]}dB is MOS: {mos[0]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute MOS metric for distorted speech with -5dB SNR\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"mos = subjective_model(WAVEFORM_DISTORTED[1:2, :], WAVEFORM_NMR)\nprint(f\"Estimated MOS for distorted speech at {snr_dbs[1]}dB is MOS: {mos[0]}\")"
]
},
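{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the estimate depends on the chosen non-matching reference, one optional refinement (a sketch, not something prescribed by the pipeline) is to average the predicted MOS over several clean references. The cell below reuses ``WAVEFORM_NMR`` twice purely for illustration; in practice you would pass different clean utterances. ``nmr_list`` and ``mean_mos`` are names introduced here for illustration.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative sketch: average MOS estimates over a list of non-matching references.\n# The list repeats WAVEFORM_NMR only because a single NMR asset is loaded in this\n# tutorial; substitute different clean utterances in practice.\nnmr_list = [WAVEFORM_NMR, WAVEFORM_NMR]\nwith torch.no_grad():\n    mos_scores = [subjective_model(WAVEFORM_DISTORTED[1:2, :], nmr)[0] for nmr in nmr_list]\nmean_mos = torch.stack(mos_scores).mean()\nprint(f\"Mean estimated MOS over {len(nmr_list)} non-matching references: {mean_mos}\")"
]
},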
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Comparison with ground truths and baselines\n\nVisualizing the estimated metrics by the ``SquimObjective`` and\n``SquimSubjective`` models can help users better understand how the\nmodels can be applicable in real scenario. The graph below shows scatter\nplots of three different systems: MOSA-Net [1], AMSA [2], and the\n``SquimObjective`` model, where y axis represents the estimated STOI,\nPESQ, and Si-SDR scores, and x axis represents the corresponding ground\ntruth.\n\n
\n\n[1] Zezario, Ryandhimas E., Szu-Wei Fu, Fei Chen, Chiou-Shann Fuh,\nHsin-Min Wang, and Yu Tsao. \u201cDeep learning-based non-intrusive\nmulti-objective speech assessment model with cross-domain features.\u201d\nIEEE/ACM Transactions on Audio, Speech, and Language Processing 31\n(2022): 54-70.\n\n[2] Dong, Xuan, and Donald S. Williamson. \u201cAn attention enhanced\nmulti-task model for objective speech assessment in real-world\nenvironments.\u201d In ICASSP 2020-2020 IEEE International Conference on\nAcoustics, Speech and Signal Processing (ICASSP), pp.\u00a0911-915. IEEE,\n2020.\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The graph below shows scatter plot of the ``SquimSubjective`` model,\nwhere y axis represents the estimated MOS metric score, and x axis\nrepresents the corresponding ground truth.\n\n
\n\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 0
}