{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Forced Alignment with Wav2Vec2\n\n**Author**: [Moto Hira](moto@meta.com)_\n\nThis tutorial shows how to align a transcript to speech with\n``torchaudio``, using the CTC segmentation algorithm described in\n[CTC-Segmentation of Large Corpora for German End-to-end Speech\nRecognition](https://arxiv.org/abs/2007.09127)_.\n\n

**Note**

This tutorial was originally written to illustrate a use case\n    for the Wav2Vec2 pretrained model.\n\n    TorchAudio now has a set of APIs designed for forced alignment.\n    The [CTC forced alignment API tutorial](./ctc_forced_alignment_api_tutorial.html)_ illustrates the\n    usage of :py:func:`torchaudio.functional.forced_align`, which is\n    the core API.\n\n    If you are looking to align your corpus, we recommend using\n    :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which combines\n    :py:func:`~torchaudio.functional.forced_align` and other support\n    functions with a pre-trained model specifically trained for\n    forced alignment. Please refer to the\n    [Forced alignment for multilingual data](forced_alignment_for_multilingual_data_tutorial.html)_ tutorial, which\n    illustrates its usage.
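
For orientation, here is a minimal sketch of how the core
:py:func:`torchaudio.functional.forced_align` call can look. It is illustrative
only and is not executed in this tutorial; the ``emission`` and ``targets``
values below are toy placeholders, and the real pipeline setup is covered in
the linked tutorials.

```python
import torch
import torchaudio.functional as F

# Toy inputs: 1 batch, 10 frames, 5 label classes (index 0 is the CTC blank),
# and a 3-token transcript [1, 2, 3].
emission = torch.randn(1, 10, 5).log_softmax(dim=-1)
targets = torch.tensor([[1, 2, 3]], dtype=torch.int32)

# forced_align returns, for every frame, the emitted label index and its score.
aligned_tokens, alignment_scores = F.forced_align(emission, targets, blank=0)
print(aligned_tokens.shape, alignment_scores.shape)  # one entry per frame
```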

\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchaudio\n\nprint(torch.__version__)\nprint(torchaudio.__version__)\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n\nThe process of alignment looks like the following:\n\n1. Estimate the frame-wise label probability from the audio waveform.\n2. Generate the trellis matrix, which represents the probability of\n   labels aligned at each time step.\n3. Find the most likely path from the trellis matrix.\n\nIn this example, we use ``torchaudio``\\ \u2019s ``Wav2Vec2`` model for\nacoustic feature extraction.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation\n\nFirst we import the necessary packages and fetch the data that we work on.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from dataclasses import dataclass\n\nimport IPython\nimport matplotlib.pyplot as plt\n\ntorch.random.manual_seed(0)\n\nSPEECH_FILE = torchaudio.utils.download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate frame-wise label probability\n\nThe first step is to generate the label class probability of each audio\nframe. We can use a Wav2Vec2 model that is trained for ASR. Here we use\n:py:data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.\n\n``torchaudio`` provides easy access to pretrained models with associated\nlabels.\n\n

**Note**

In the subsequent sections, we will compute the probability in the\n    log domain to avoid numerical instability. For this purpose, we\n    normalize the ``emission`` with :py:func:`torch.log_softmax`.
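
As a quick standalone illustration of why the log domain matters (a toy
example, not part of the pipeline below): multiplying many small per-frame
probabilities underflows in floating point, whereas summing their logarithms
stays comfortably representable.

```python
import torch

# 2000 frames, each with probability 0.1: the product underflows to zero,
# but the equivalent sum of log-probabilities is an ordinary finite number.
p = torch.full((2000,), 0.1)
print(p.prod())       # tensor(0.)
print(p.log().sum())  # approximately -4605.17, i.e. 2000 * log(0.1)
```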

\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H\nmodel = bundle.get_model().to(device)\nlabels = bundle.get_labels()\nwith torch.inference_mode():\n    waveform, _ = torchaudio.load(SPEECH_FILE)\n    emissions, _ = model(waveform.to(device))\n    emissions = torch.log_softmax(emissions, dim=-1)\n\nemission = emissions[0].cpu().detach()\n\nprint(labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot():\n    fig, ax = plt.subplots()\n    img = ax.imshow(emission.T)\n    ax.set_title(\"Frame-wise class probability\")\n    ax.set_xlabel(\"Time\")\n    ax.set_ylabel(\"Labels\")\n    fig.colorbar(img, ax=ax, shrink=0.6, location=\"bottom\")\n    fig.tight_layout()\n\n\nplot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate alignment probability (trellis)\n\nFrom the emission matrix, we next generate the trellis, which represents\nthe probability of each transcript label occurring at each time frame.\n\nThe trellis is a 2D matrix with a time axis and a label axis. The label axis\nrepresents the transcript that we are aligning. In the following, we use\n$t$ to denote the index in the time axis and $j$ to denote the\nindex in the label axis. $c_j$ represents the label at label index\n$j$.\n\nTo generate the probability of time step $t+1$, we look at the\ntrellis from time step $t$ and the emission at time step $t+1$.\nThere are two paths to reach time step $t+1$ with label\n$c_{j+1}$. The first one is the case where the label was\n$c_{j+1}$ at $t$ and there was no label change from\n$t$ to $t+1$. The other case is where the label was\n$c_j$ at $t$ and it transitioned to the next label\n$c_{j+1}$ at $t+1$.\n\nThe following diagram illustrates this transition.\n\n\n\nSince we are looking for the most likely transitions, we take the more\nlikely path for the value of $k_{(t+1, j+1)}$, that is\n\n$k_{(t+1, j+1)} = \\max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )$\n\nwhere $k$ represents the trellis matrix, and $p(t, c_j)$\nrepresents the probability of label $c_j$ at time step $t$.\n$repeat$ represents the blank token from the CTC formulation. 
(For details of the CTC algorithm, please refer to *Sequence Modeling with CTC*\n[[distill.pub](https://distill.pub/2017/ctc/)_].)\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# We enclose the transcript with space tokens, which represent SOS and EOS.\ntranscript = \"|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|\"\ndictionary = {c: i for i, c in enumerate(labels)}\n\ntokens = [dictionary[c] for c in transcript]\nprint(list(zip(transcript, tokens)))\n\n\ndef get_trellis(emission, tokens, blank_id=0):\n    num_frame = emission.size(0)\n    num_tokens = len(tokens)\n\n    trellis = torch.zeros((num_frame, num_tokens))\n    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)\n    trellis[0, 1:] = -float(\"inf\")\n    trellis[-num_tokens + 1 :, 0] = float(\"inf\")\n\n    for t in range(num_frame - 1):\n        trellis[t + 1, 1:] = torch.maximum(\n            # Score for staying at the same token\n            trellis[t, 1:] + emission[t, blank_id],\n            # Score for changing to the next token\n            trellis[t, :-1] + emission[t, tokens[1:]],\n        )\n    return trellis\n\n\ntrellis = get_trellis(emission, tokens)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot():\n    fig, ax = plt.subplots()\n    img = ax.imshow(trellis.T, origin=\"lower\")\n    ax.annotate(\"- Inf\", (trellis.size(1) / 5, trellis.size(1) / 1.5))\n    ax.annotate(\"+ Inf\", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))\n    fig.colorbar(img, ax=ax, shrink=0.6, location=\"bottom\")\n    fig.tight_layout()\n\n\nplot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the above visualization, we can see that there is a trace of high\nprobability crossing the matrix diagonally.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Find the most likely path (backtracking)\n\nOnce the trellis is generated, we will traverse it following the\nelements with high probability.\n\nWe will start from the last label index at the last time step, then we\ntraverse back in time, picking stay\n($c_j \\rightarrow c_j$) or transition\n($c_j \\rightarrow c_{j+1}$), based on the post-transition\nprobability $k_{t, j} p(t+1, c_{j+1})$ or\n$k_{t, j+1} p(t+1, repeat)$.\n\nThe traversal is done once the label index reaches the beginning of the\ntranscript.\n\nThe trellis matrix is used for path-finding, but for the final\nprobability of each segment, we take the frame-wise probability from the\nemission matrix.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@dataclass\nclass Point:\n    token_index: int\n    time_index: int\n    score: float\n\n\ndef backtrack(trellis, emission, tokens, blank_id=0):\n    t, j = trellis.size(0) - 1, trellis.size(1) - 1\n\n    path = [Point(j, t, emission[t, blank_id].exp().item())]\n    while j > 0:\n        # Should not happen but just in case\n        assert t > 0\n\n        # 1. Figure out if the current position was stay or change
\n        # Frame-wise score of stay vs change\n        p_stay = emission[t - 1, blank_id]\n        p_change = emission[t - 1, tokens[j]]\n\n        # Context-aware score for stay vs change\n        stayed = trellis[t - 1, j] + p_stay\n        changed = trellis[t - 1, j - 1] + p_change\n\n        # Update position\n        t -= 1\n        if changed > stayed:\n            j -= 1\n\n        # Store the path with frame-wise probability.\n        prob = (p_change if changed > stayed else p_stay).exp().item()\n        path.append(Point(j, t, prob))\n\n    # Now j == 0, which means it reached the SoS.\n    # Fill up the rest for the sake of visualization\n    while t > 0:\n        prob = emission[t - 1, blank_id].exp().item()\n        path.append(Point(j, t - 1, prob))\n        t -= 1\n\n    return path[::-1]\n\n\npath = backtrack(trellis, emission, tokens)\nfor p in path:\n    print(p)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_trellis_with_path(trellis, path):\n    # To plot trellis with path, we take advantage of 'nan' value\n    trellis_with_path = trellis.clone()\n    for _, p in enumerate(path):\n        trellis_with_path[p.time_index, p.token_index] = float(\"nan\")\n    plt.imshow(trellis_with_path.T, origin=\"lower\")\n    plt.title(\"The path found by backtracking\")\n    plt.tight_layout()\n\n\nplot_trellis_with_path(trellis, path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looking good.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Segment the path\n\nNow this path contains repetitions of the same labels, so\nlet\u2019s merge them to make it close to the original transcript.\n\nWhen merging multiple path points, we simply take the average\nprobability of the merged segments.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Merge the labels\n@dataclass\nclass Segment:\n    label: str\n    start: int\n    end: int\n    score: float\n\n    def __repr__(self):\n        return f\"{self.label}\\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})\"\n\n    @property\n    def length(self):\n        return self.end - self.start\n\n\ndef merge_repeats(path):\n    i1, i2 = 0, 0\n    segments = []\n    while i1 < len(path):\n        while i2 < len(path) and path[i1].token_index == path[i2].token_index:\n            i2 += 1\n        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)\n        segments.append(\n            Segment(\n                transcript[path[i1].token_index],\n                path[i1].time_index,\n                path[i2 - 1].time_index + 1,\n                score,\n            )\n        )\n        i1 = i2\n    return segments\n\n\nsegments = merge_repeats(path)\nfor seg in segments:\n    print(seg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_trellis_with_segments(trellis, segments, transcript):\n    # To plot trellis with path, we take advantage of 'nan' value\n    trellis_with_path = trellis.clone()\n    for i, seg in enumerate(segments):\n        if seg.label != \"|\":\n            trellis_with_path[seg.start : seg.end, i] = float(\"nan\")\n\n    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)\n    ax1.set_title(\"Path, label and probability for each label\")\n    ax1.imshow(trellis_with_path.T, origin=\"lower\", aspect=\"auto\")\n\n    for i, seg in enumerate(segments):\n        if seg.label != \"|\":\n            ax1.annotate(seg.label, (seg.start, i - 0.7), size=\"small\")\n            ax1.annotate(f\"{seg.score:.2f}\", (seg.start, i + 3), 
size=\"small\")\n\n    ax2.set_title(\"Label probability with and without repetition\")\n    xs, hs, ws = [], [], []\n    for seg in segments:\n        if seg.label != \"|\":\n            xs.append((seg.end + seg.start) / 2 + 0.4)\n            hs.append(seg.score)\n            ws.append(seg.end - seg.start)\n            ax2.annotate(seg.label, (seg.start + 0.8, -0.07))\n    ax2.bar(xs, hs, width=ws, color=\"gray\", alpha=0.5, edgecolor=\"black\")\n\n    xs, hs = [], []\n    for p in path:\n        label = transcript[p.token_index]\n        if label != \"|\":\n            xs.append(p.time_index + 1)\n            hs.append(p.score)\n\n    ax2.bar(xs, hs, width=0.5, alpha=0.5)\n    ax2.axhline(0, color=\"black\")\n    ax2.grid(True, axis=\"y\")\n    ax2.set_ylim(-0.1, 1.1)\n    fig.tight_layout()\n\n\nplot_trellis_with_segments(trellis, segments, transcript)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Looks good.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Merge the segments into words\n\nNow let\u2019s merge the segments into words. The Wav2Vec2 model uses ``'|'``\nas the word boundary, so we merge the segments before each occurrence of\n``'|'``.\n\nThen, finally, we split the original audio into word segments and\nlisten to them to see if the segmentation is correct.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Merge words\ndef merge_words(segments, separator=\"|\"):\n    words = []\n    i1, i2 = 0, 0\n    while i1 < len(segments):\n        if i2 >= len(segments) or segments[i2].label == separator:\n            if i1 != i2:\n                segs = segments[i1:i2]\n                word = \"\".join([seg.label for seg in segs])\n                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)\n                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))\n            i1 = i2 + 1\n            i2 = i1\n        else:\n            i2 += 1\n    return words\n\n\nword_segments = merge_words(segments)\nfor word in word_segments:\n    print(word)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_alignments(trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate):\n    trellis_with_path = trellis.clone()\n    for i, seg in enumerate(segments):\n        if seg.label != \"|\":\n            trellis_with_path[seg.start : seg.end, i] = float(\"nan\")\n\n    fig, [ax1, ax2] = plt.subplots(2, 1)\n\n    ax1.imshow(trellis_with_path.T, origin=\"lower\", aspect=\"auto\")\n    ax1.set_facecolor(\"lightgray\")\n    ax1.set_xticks([])\n    ax1.set_yticks([])\n\n    for word in word_segments:\n        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor=\"white\", facecolor=\"none\")\n\n    for i, seg in enumerate(segments):\n        if seg.label != \"|\":\n            ax1.annotate(seg.label, (seg.start, i - 0.7), size=\"small\")\n            ax1.annotate(f\"{seg.score:.2f}\", (seg.start, i + 3), size=\"small\")\n\n    # The original waveform\n    ratio = waveform.size(0) / sample_rate / trellis.size(0)\n    ax2.specgram(waveform, Fs=sample_rate)\n    for word in word_segments:\n        x0 = ratio * word.start\n        x1 = ratio * word.end\n        ax2.axvspan(x0, x1, facecolor=\"none\", edgecolor=\"white\", hatch=\"/\")\n        ax2.annotate(f\"{word.score:.2f}\", (x0, sample_rate * 0.51), annotation_clip=False)\n\n    for seg in segments:\n        if seg.label != \"|\":\n            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)\n    ax2.set_xlabel(\"time [second]\")\n    ax2.set_yticks([])\n    fig.tight_layout()\n\n\nplot_alignments(\n    trellis,\n    segments,\n    word_segments,\n    waveform[0],\n)" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## Audio Samples\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def display_segment(i):\n    ratio = waveform.size(1) / trellis.size(0)\n    word = word_segments[i]\n    x0 = int(ratio * word.start)\n    x1 = int(ratio * word.end)\n    print(f\"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec\")\n    segment = waveform[:, x0:x1]\n    return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Generate the audio for each segment\nprint(transcript)\nIPython.display.Audio(SPEECH_FILE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(7)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display_segment(8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n\nIn this tutorial, we looked at how to use torchaudio\u2019s Wav2Vec2 model to\nperform CTC segmentation for forced alignment.\n\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }