{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Speech Recognition with Wav2Vec2\n\n**Author**: [Moto Hira](moto@meta.com)_\n\nThis tutorial shows how to perform speech recognition using using\npre-trained models from wav2vec 2.0\n[[paper](https://arxiv.org/abs/2006.11477)_].\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n\nThe process of speech recognition looks like the following.\n\n1. Extract the acoustic features from audio waveform\n\n2. Estimate the class of the acoustic features frame-by-frame\n\n3. Generate hypothesis from the sequence of the class probabilities\n\nTorchaudio provides easy access to the pre-trained weights and\nassociated information, such as the expected sample rate and class\nlabels. They are bundled together and available under\n:py:mod:`torchaudio.pipelines` module.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchaudio\n\nprint(torch.__version__)\nprint(torchaudio.__version__)\n\ntorch.random.manual_seed(0)\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nprint(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import IPython\nimport matplotlib.pyplot as plt\nfrom torchaudio.utils import download_asset\n\nSPEECH_FILE = download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a pipeline\n\nFirst, we will create a Wav2Vec2 model that performs the feature\nextraction and the classification.\n\nThere are two types of Wav2Vec2 pre-trained weights available in\ntorchaudio. The ones fine-tuned for ASR task, and the ones not\nfine-tuned.\n\nWav2Vec2 (and HuBERT) models are trained in self-supervised manner. They\nare firstly trained with audio only for representation learning, then\nfine-tuned for a specific task with additional labels.\n\nThe pre-trained weights without fine-tuning can be fine-tuned\nfor other downstream tasks as well, but this tutorial does not\ncover that.\n\nWe will use :py:data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H` here.\n\nThere are multiple pre-trained models available in :py:mod:`torchaudio.pipelines`.\nPlease check the documentation for the detail of how they are trained.\n\nThe bundle object provides the interface to instantiate model and other\ninformation. Sampling rate and the class labels are found as follow.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H\n\nprint(\"Sample Rate:\", bundle.sample_rate)\n\nprint(\"Labels:\", bundle.get_labels())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model can be constructed as following. 
This process will automatically\nfetch the pre-trained weights and load them into the model.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = bundle.get_model().to(device)\n\nprint(model.__class__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading data\n\nWe will use the speech data from the [VOiCES\ndataset](https://iqtlabs.github.io/voices/), which is licensed under\nCreative Commons BY 4.0.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "IPython.display.Audio(SPEECH_FILE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load data, we use :py:func:`torchaudio.load`.\n\nIf the sampling rate is different from what the pipeline expects, then\nwe can use :py:func:`torchaudio.functional.resample` for resampling.\n\n
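If you want to check the sampling rate of a file before loading the full waveform, :py:func:`torchaudio.info` can be used. A minimal sketch, reusing the ``SPEECH_FILE`` downloaded above::\n\n    # Inspect the file's metadata (sample rate, channels, number of frames).\n    metadata = torchaudio.info(SPEECH_FILE)\n    print(metadata.sample_rate, metadata.num_channels, metadata.num_frames)\n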

**Note**\n\n- :py:func:`torchaudio.functional.resample` works on CUDA tensors as well.\n- When performing resampling multiple times on the same set of sample rates,\n  using :py:class:`torchaudio.transforms.Resample` might improve the performance (see the sketch below).
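For that case, the transform can be built once and reused. A minimal sketch, assuming the ``waveform``, ``sample_rate``, ``bundle``, and ``device`` variables used in this tutorial::\n\n    # Build the transform once; it precomputes and caches the resampling kernel.\n    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=bundle.sample_rate).to(device)\n    # Apply it to as many waveforms (with the same source rate) as needed.\n    resampled = resampler(waveform)\n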

\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "waveform, sample_rate = torchaudio.load(SPEECH_FILE)\nwaveform = waveform.to(device)\n\nif sample_rate != bundle.sample_rate:\n waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extracting acoustic features\n\nThe next step is to extract acoustic features from the audio.\n\n

**Note**\n\nWav2Vec2 models fine-tuned for the ASR task can perform feature\nextraction and classification in one step, but for the sake of the\ntutorial, we also show how to perform feature extraction here.
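The next cell extracts features from all of the transformer layers. If only the first few layers are needed, ``extract_features`` also accepts a ``num_layers`` argument in recent torchaudio versions. A sketch, not used in the rest of this tutorial::\n\n    with torch.inference_mode():\n        # Run only the first four transformer layers and return their outputs.\n        partial_features, _ = model.extract_features(waveform, num_layers=4)\n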

\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.inference_mode():\n features, _ = model.extract_features(waveform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The returned features is a list of tensors. Each tensor is the output of\na transformer layer.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, ax = plt.subplots(len(features), 1, figsize=(16, 4.3 * len(features)))\nfor i, feats in enumerate(features):\n ax[i].imshow(feats[0].cpu(), interpolation=\"nearest\")\n ax[i].set_title(f\"Feature from transformer layer {i+1}\")\n ax[i].set_xlabel(\"Feature dimension\")\n ax[i].set_ylabel(\"Frame (time-axis)\")\nfig.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature classification\n\nOnce the acoustic features are extracted, the next step is to classify\nthem into a set of categories.\n\nWav2Vec2 model provides method to perform the feature extraction and\nclassification in one step.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.inference_mode():\n emission, _ = model(waveform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output is in the form of logits. It is not in the form of\nprobability.\n\nLet\u2019s visualize this.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.imshow(emission[0].cpu().T, interpolation=\"nearest\")\nplt.title(\"Classification result\")\nplt.xlabel(\"Frame (time-axis)\")\nplt.ylabel(\"Class\")\nplt.tight_layout()\nprint(\"Class labels:\", bundle.get_labels())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that there are strong indications to certain labels across\nthe time line.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generating transcripts\n\nFrom the sequence of label probabilities, now we want to generate\ntranscripts. The process to generate hypotheses is often called\n\u201cdecoding\u201d.\n\nDecoding is more elaborate than simple classification because\ndecoding at certain time step can be affected by surrounding\nobservations.\n\nFor example, take a word like ``night`` and ``knight``. Even if their\nprior probability distribution are differnt (in typical conversations,\n``night`` would occur way more often than ``knight``), to accurately\ngenerate transcripts with ``knight``, such as ``a knight with a sword``,\nthe decoding process has to postpone the final decision until it sees\nenough context.\n\nThere are many decoding techniques proposed, and they require external\nresources, such as word dictionary and language models.\n\nIn this tutorial, for the sake of simplicity, we will perform greedy\ndecoding which does not depend on such external components, and simply\npick up the best hypothesis at each time step. 
We start by defining the greedy decoding algorithm.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class GreedyCTCDecoder(torch.nn.Module):\n    def __init__(self, labels, blank=0):\n        super().__init__()\n        self.labels = labels\n        self.blank = blank\n\n    def forward(self, emission: torch.Tensor) -> str:\n        \"\"\"Given a sequence emission over labels, get the best path string\n\n        Args:\n            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.\n\n        Returns:\n            str: The resulting transcript\n        \"\"\"\n        indices = torch.argmax(emission, dim=-1)  # [num_seq,]\n        indices = torch.unique_consecutive(indices, dim=-1)\n        indices = [i for i in indices if i != self.blank]\n        return \"\".join([self.labels[i] for i in indices])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now create the decoder object and decode the transcript.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "decoder = GreedyCTCDecoder(labels=bundle.get_labels())\ntranscript = decoder(emission[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\u2019s check the result and listen again to the audio.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(transcript)\nIPython.display.Audio(SPEECH_FILE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The ASR model is fine-tuned using a loss function called Connectionist Temporal Classification (CTC).\nThe details of the CTC loss are explained\n[here](https://distill.pub/2017/ctc/). In CTC, a blank token (\u03f5) is a\nspecial token which represents the absence of an output at a given frame;\nit also separates consecutive repetitions of the same symbol. In\ndecoding, blank tokens are simply ignored.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n\nIn this tutorial, we looked at how to use :py:class:`~torchaudio.pipelines.Wav2Vec2ASRBundle` to\nperform acoustic feature extraction and speech recognition. Constructing\na model and getting the emission is as short as two lines.\n\n::\n\n    model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()\n    emission = model(waveforms, ...)\n\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }