{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# ASR Inference with CTC Decoder\n\n**Author**: [Caroline Chen](mailto:carolinechen@meta.com)\n\nThis tutorial shows how to perform speech recognition inference using a\nCTC beam search decoder with lexicon constraint and KenLM language model\nsupport. We demonstrate this on a pretrained wav2vec 2.0 model trained\nusing CTC loss.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n\nBeam search decoding works by iteratively expanding text hypotheses (beams)\nwith the next possible characters and maintaining only the hypotheses with the\nhighest scores at each time step. A language model can be incorporated into\nthe scoring computation, and adding a lexicon constraint restricts the\nnext possible tokens for the hypotheses so that only words from the lexicon\ncan be generated.\n\nThe underlying implementation is ported from [Flashlight](https://arxiv.org/pdf/2201.12465.pdf)'s\nbeam search decoder.
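The expand-then-prune loop described above can be illustrated with a short plain-Python sketch. It is a simplified, hypothetical example rather than the Flashlight algorithm: it keeps the `beam_size` best frame-level token paths by summed log-probability, and it omits the CTC prefix merging, lexicon constraint, and language model scoring that the real decoder performs. The `path_beam_search` name and the toy emission values are assumptions made for illustration.

```python
import math
from typing import List, Tuple

Beam = Tuple[List[int], float]  # (frame-level token path, summed log-probability)


def path_beam_search(log_probs: List[List[float]], beam_size: int) -> List[Beam]:
    beams: List[Beam] = [([], 0.0)]
    for frame in log_probs:
        # Expand every surviving hypothesis with every possible next token.
        candidates = [
            (path + [token], score + lp)
            for path, score in beams
            for token, lp in enumerate(frame)
        ]
        # Keep only the `beam_size` highest-scoring hypotheses for the next step.
        candidates.sort(key=lambda beam: beam[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


# Toy 3-frame emission over 3 tokens (index 0 = blank), as log-probabilities.
emission = [
    [math.log(0.6), math.log(0.3), math.log(0.1)],
    [math.log(0.2), math.log(0.7), math.log(0.1)],
    [math.log(0.5), math.log(0.1), math.log(0.4)],
]
for path, score in path_beam_search(emission, beam_size=2):
    print(path, round(math.exp(score), 3))
# prints [0, 1, 0] 0.21, then [0, 1, 2] 0.168
```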
A mathematical formula for the decoder optimization can be\nfound in the [Wav2Letter paper](https://arxiv.org/pdf/1609.03193.pdf), and\na more detailed algorithm can be found in this [blog](https://towardsdatascience.com/boosting-your-sequence-generation-performance-with-beam-search-language-model-decoding-74ee64de435a).\n\nRunning ASR inference using a CTC beam search decoder with a language\nmodel and lexicon constraint requires the following components:\n\n- Acoustic Model: model that predicts per-frame token scores (emissions) from audio waveforms\n- Tokens: the set of tokens the acoustic model can predict\n- Lexicon: mapping between possible words and their corresponding\n token sequences\n- Language Model (LM): an n-gram language model trained with the [KenLM\n library](https://kheafield.com/code/kenlm/), or a custom language\n model that inherits from :py:class:`~torchaudio.models.decoder.CTCDecoderLM`\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Acoustic Model and Setup\n\nFirst, we import the necessary utilities and fetch the data that we are\nworking with.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchaudio\n\nprint(torch.__version__)\nprint(torchaudio.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import time\nfrom typing import List\n\nimport IPython\nimport matplotlib.pyplot as plt\nfrom torchaudio.models.decoder import ctc_decoder\nfrom torchaudio.utils import download_asset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the pretrained [Wav2Vec 2.0](https://arxiv.org/abs/2006.11477)\nBase model that is fine-tuned on 10 minutes of the [LibriSpeech\ndataset](http://www.openslr.org/12), which can be loaded using\n:data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M`.\nFor more detail on running Wav2Vec 2.0 speech\nrecognition pipelines in torchaudio, please
refer to [this\ntutorial](./speech_recognition_pipeline_tutorial.html).\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M\nacoustic_model = bundle.get_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will load a sample from the LibriSpeech test-other dataset.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "speech_file = download_asset(\"tutorial-assets/ctc-decoding/1688-142285-0007.wav\")\n\nIPython.display.Audio(speech_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The transcript corresponding to this audio file is:\n\n```\ni really was very much afraid of showing him how much shocked i was at some parts of what he said\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "waveform, sample_rate = torchaudio.load(speech_file)\n\n# Resample to the sample rate expected by the acoustic model, if needed\nif sample_rate != bundle.sample_rate:\n    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Files and Data for Decoder\n\nNext, we load in our token, lexicon, and language model data, which are used\nby the decoder to predict words from the acoustic model output. Pretrained\nfiles for the LibriSpeech dataset can be downloaded through torchaudio,\nor the user can provide their own files.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokens\n\nThe tokens are the possible symbols that the acoustic model can predict,\nincluding the blank and silence symbols.
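To make the roles of the blank and silence symbols concrete, the following minimal sketch turns a frame-level token path into words the way CTC output is usually read: consecutive repeats are merged, blanks are dropped, and the silence token `|` separates words. It assumes `-` as the blank and `|` as the silence token (the defaults of the `blank_token` and `sil_token` arguments of `ctc_decoder`); the token list and the `tokens_to_words` helper are hypothetical.

```python
from typing import List

# Hypothetical token inventory: index 0 is the blank, index 1 the silence (word boundary).
TOKENS = ['-', '|', 'a', 'b', 'c', 't']


def tokens_to_words(path: List[int], tokens: List[str] = TOKENS, blank: str = '-', sil: str = '|') -> List[str]:
    chars: List[str] = []
    prev = None
    for index in path:
        # Merge consecutive repeats of the same index (standard CTC collapse).
        if index != prev and tokens[index] != blank:
            # Blanks emit nothing; they only separate frames.
            chars.append(tokens[index])
        prev = index
    # The silence token marks word boundaries.
    return ''.join(chars).replace(sil, ' ').split()


# c a a - t | b a t  collapses to  ['cat', 'bat']
print(tokens_to_words([4, 2, 2, 0, 5, 1, 3, 2, 5]))
# A blank between two identical tokens keeps both: ['aa']
print(tokens_to_words([2, 0, 2]))
```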
They can be passed in either as a\nfile, where the token on each line corresponds to that line's\nindex, or as a list of tokens, each mapping to a unique index.\n\n```\n# tokens.txt\n_\n|\ne\nt\n...\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "tokens = [label.lower() for label in bundle.get_labels()]\nprint(tokens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Lexicon\n\nThe lexicon is a mapping from words to their corresponding token\nsequences, and is used to restrict the search space of the decoder to\nonly words from the lexicon. The expected format of the lexicon file is\none line per word: the word followed by its space-separated tokens.\n\n```\n# lexicon.txt\na a |\nable a b l e |\nabout a b o u t |\n...\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Language Model\n\nA language model can improve decoding results by adding a language model\nscore, which represents the likelihood of the word sequence, to the beam\nsearch computation. Below, we outline the different forms of language\nmodels that are supported for decoding.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### No Language Model\n\nTo create a decoder instance without a language model, set `lm=None`\nwhen initializing the decoder.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### KenLM\n\nThis is an n-gram language model trained with the [KenLM\nlibrary](https://kheafield.com/code/kenlm/).
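KenLM itself is a C++ toolkit, but the role an n-gram model plays in decoding can be illustrated with a toy bigram model. The sketch below is a hypothetical simplification: the bigram table, the `BOS`/`EOS` markers, and the floor probability for unseen pairs (a crude stand-in for KenLM's backoff) are all assumptions. It combines the LM score with an acoustic score in the spirit of the Wav2Letter decoding objective, namely acoustic score plus `lm_weight` times the LM log-probability plus `word_score` times the number of words, ignoring other terms such as silence scores.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical bigram probabilities P(word | previous word).
# BOS and EOS are assumed sentence start and end markers.
BIGRAMS: Dict[Tuple[str, str], float] = {
    ('BOS', 'i'): 0.2,
    ('i', 'was'): 0.1,
    ('was', 'afraid'): 0.05,
    ('afraid', 'EOS'): 0.3,
}


def bigram_log_prob(words: List[str], floor: float = 1e-6) -> float:
    # Sum log P(w_k | w_(k-1)) over the sentence, including the end marker.
    # Unseen pairs fall back to a small floor probability (a crude stand-in for backoff).
    history = ['BOS'] + words
    following = words + ['EOS']
    return sum(math.log(BIGRAMS.get(pair, floor)) for pair in zip(history, following))


def combined_score(am_score: float, words: List[str], lm_weight: float, word_score: float) -> float:
    # Acoustic score + weighted LM log-probability + a per-word bonus (a penalty if negative).
    return am_score + lm_weight * bigram_log_prob(words) + word_score * len(words)


# am_score is a made-up acoustic log-probability; the weights match LM_WEIGHT and WORD_SCORE used later.
print(combined_score(-12.0, ['i', 'was', 'afraid'], lm_weight=3.23, word_score=-0.26))
```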
Either the ``.arpa`` or\nthe binarized ``.bin`` LM can be used, but the binary format is\nrecommended for faster loading.\n\nThe language model used in this tutorial is a 4-gram KenLM trained on\n[LibriSpeech](http://www.openslr.org/11).\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Custom Language Model\n\nUsers can define their own custom language model in Python, whether\nit be a statistical or neural network language model, using\n:py:class:`~torchaudio.models.decoder.CTCDecoderLM` and\n:py:class:`~torchaudio.models.decoder.CTCDecoderLMState`.\n\nFor instance, the following code creates a basic wrapper around a PyTorch\n``torch.nn.Module`` language model.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState\n\n\nclass CustomLM(CTCDecoderLM):\n    \"\"\"Create a Python wrapper around `language_model` to feed to the decoder.\"\"\"\n\n    def __init__(self, language_model: torch.nn.Module):\n        CTCDecoderLM.__init__(self)\n        self.language_model = language_model\n        self.sil = -1  # index for silent token in the language model\n        self.states = {}  # cache of LM scores, keyed by decoder state\n\n        language_model.eval()\n\n    def start(self, start_with_nothing: bool = False):\n        # The initial state is scored with the silence token\n        state = CTCDecoderLMState()\n        with torch.no_grad():\n            score = self.language_model(self.sil)\n\n        self.states[state] = score\n        return state\n\n    def score(self, state: CTCDecoderLMState, token_index: int):\n        # Extend `state` with `token_index`, computing each child's score only once\n        outstate = state.child(token_index)\n        if outstate not in self.states:\n            with torch.no_grad():\n                score = self.language_model(token_index)\n            self.states[outstate] = score\n        score = self.states[outstate]\n\n        return outstate, score\n\n    def finish(self, state: CTCDecoderLMState):\n        # Close the sequence by scoring the silence token\n        return self.score(state, self.sil)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Downloading Pretrained Files\n\nPretrained files for the LibriSpeech dataset can be downloaded
using\n:py:func:`~torchaudio.models.decoder.download_pretrained_files`.\n\nNote: this cell may take a couple of minutes to run, as the language\nmodel can be large.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchaudio.models.decoder import download_pretrained_files\n\nfiles = download_pretrained_files(\"librispeech-4-gram\")\n\nprint(files)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Construct Decoders\n\nIn this tutorial, we construct both a beam search decoder and a greedy decoder\nfor comparison.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Beam Search Decoder\n\nThe decoder can be constructed using the factory function\n:py:func:`~torchaudio.models.decoder.ctc_decoder`.\nIn addition to the previously mentioned components, it also takes various beam\nsearch decoding parameters (such as `nbest`, `beam_size`, `lm_weight`, and `word_score`)\nand token/word parameters (such as `blank_token`, `sil_token`, and `unk_word`).\n\nThis decoder can also be run without a language model by passing `None` to the\n`lm` parameter.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "LM_WEIGHT = 3.23\nWORD_SCORE = -0.26\n\nbeam_search_decoder = ctc_decoder(\n    lexicon=files.lexicon,\n    tokens=files.tokens,\n    lm=files.lm,\n    nbest=3,\n    beam_size=1500,\n    lm_weight=LM_WEIGHT,\n    word_score=WORD_SCORE,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Greedy Decoder\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class GreedyCTCDecoder(torch.nn.Module):\n    def __init__(self, labels, blank=0):\n        super().__init__()\n        self.labels = labels\n        self.blank = blank\n\n    def forward(self, emission: torch.Tensor) -> List[str]:\n        \"\"\"Given a sequence emission over labels, get the best path.\n\n        Args:\n            emission (Tensor): Logit tensors.
Shape `[num_seq, num_label]`.\n\n        Returns:\n            List[str]: The resulting transcript, split into words.\n        \"\"\"\n        indices = torch.argmax(emission, dim=-1)  # [num_seq,]\n        indices = torch.unique_consecutive(indices, dim=-1)\n        indices = [i for i in indices if i != self.blank]\n        joined = \"\".join([self.labels[i] for i in indices])\n        return joined.replace(\"|\", \" \").strip().split()\n\n\ngreedy_decoder = GreedyCTCDecoder(tokens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run Inference\n\nNow that we have the data, acoustic model, and decoder, we can perform\ninference. For each utterance in the batch, the beam search decoder returns a list of\nthe `nbest` best hypotheses, sorted by score. Each hypothesis is of type\n:py:class:`~torchaudio.models.decoder.CTCHypothesis` and consists of the\npredicted token IDs, the corresponding words (if a lexicon is provided), the hypothesis score,\nand the timesteps corresponding to the token IDs. Recall that the transcript corresponding to the\nwaveform is:\n\n```\ni really was very much afraid of showing him how much shocked i was at some parts of what he said\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "actual_transcript = \"i really was very much afraid of showing him how much shocked i was at some parts of what he said\"\nactual_transcript = actual_transcript.split()\n\n# Run the acoustic model forward only, without tracking gradients\nwith torch.inference_mode():\n    emission, _ = acoustic_model(waveform)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The greedy decoder gives the following result.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "greedy_result = greedy_decoder(emission[0])\ngreedy_transcript = \" \".join(greedy_result)\ngreedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)\n\nprint(f\"Transcript: {greedy_transcript}\")\nprint(f\"WER: {greedy_wer}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using the beam search decoder:\n\n\n" ] }, { "cell_type": "code", "execution_count": null,
"metadata": { "collapsed": false }, "outputs": [], "source": [ "beam_search_result = beam_search_decoder(emission)\n# [0][0]: best hypothesis for the first (and only) utterance in the batch\nbeam_search_transcript = \" \".join(beam_search_result[0][0].words).strip()\nbeam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(\n    actual_transcript\n)\n\nprint(f\"Transcript: {beam_search_transcript}\")\nprint(f\"WER: {beam_search_wer}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
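Both WER figures above are the word-level edit distance (substitutions, insertions, and deletions) between the reference and the hypothesis, divided by the number of reference words. As a hedged illustration of that metric, here is a plain-Python Levenshtein sketch intended to agree with `torchaudio.functional.edit_distance` on lists of words; `word_edit_distance` and `wer` are hypothetical helper names, not torchaudio APIs.

```python
from typing import List, Sequence


def word_edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    # Levenshtein dynamic program over words, keeping one row of the table at a time.
    prev = list(range(len(hyp) + 1))  # distances from an empty reference prefix
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,  # deletion: reference word missing from the hypothesis
                curr[j - 1] + 1,  # insertion: extra word in the hypothesis
                prev[j - 1] + (ref_word != hyp_word),  # substitution, or free match
            )
        prev = curr
    return prev[-1]


def wer(ref: List[str], hyp: List[str]) -> float:
    # Word error rate: edit distance normalized by the reference length (ref must be non-empty).
    return word_edit_distance(ref, hyp) / len(ref)


# One substitution (really vs. rarely) and one insertion (of) over 6 reference words.
print(wer('i really was very much afraid'.split(), 'i rarely was very much afraid of'.split()))
```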
The :py:attr:`~torchaudio.models.decoder.CTCHypothesis.words`\n field of the output hypotheses will be empty if no lexicon\n is provided to the decoder. To retrieve a transcript with lexicon-free\n decoding, convert the token indices of a hypothesis back to tokens with\n :py:meth:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens`, join them into\n a single string, and replace the `|` word delimiters with spaces:\n\n .. code::\n\n tokens_str = \"\".join(beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens))\n transcript = \" \".join(tokens_str.split(\"|\"))