{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# CTC forced alignment API tutorial\n\n**Author**: [Xiaohui Zhang](xiaohuizhang@meta.com), [Moto Hira](moto@meta.com)\n\nForced alignment is a process to align a transcript with speech.\nThis tutorial shows how to align transcripts to speech using\n:py:func:`torchaudio.functional.forced_align`, which was developed along with the work of\n[Scaling Speech Technology to 1,000+ Languages](https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/).\n\n:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA\nimplementations which are more performant than a vanilla Python\nimplementation, and are more accurate.\nIt can also handle a partially missing transcript with the special ``<star>`` token.\n\nThere is also a high-level API, :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`,\nwhich wraps the pre/post-processing explained in this tutorial and makes it easy\nto run forced alignment.\n[Forced alignment for multilingual data](./forced_alignment_for_multilingual_data_tutorial.html) uses this API to\nillustrate how to align non-English transcripts.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparation\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchaudio\n\nprint(torch.__version__)\nprint(torchaudio.__version__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import IPython\nimport matplotlib.pyplot as plt\n\nimport torchaudio.functional as F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we prepare the speech data and the transcript we are going\nto use.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "SPEECH_FILE = torchaudio.utils.download_asset(\"tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav\")\nwaveform, _ = torchaudio.load(SPEECH_FILE)\nTRANSCRIPT = \"i had that curiosity beside me at this moment\".split()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating emissions\n\n:py:func:`~torchaudio.functional.forced_align` takes emission and\ntoken sequences and outputs timestamps of the tokens and their scores.\n\nThe emission represents the frame-wise probability distribution over\ntokens, and it can be obtained by passing a waveform to an acoustic\nmodel.\n\nTokens are numerical expressions of the transcript. 
There are many ways to\ntokenize transcripts, but here we simply map each letter to an integer,\nwhich is how the labels were constructed when the acoustic model we are\ngoing to use was trained.\n\nWe will use a pre-trained Wav2Vec2 model,\n:py:data:`torchaudio.pipelines.MMS_FA`, to obtain emission and tokenize\nthe transcript.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "bundle = torchaudio.pipelines.MMS_FA\n\nmodel = bundle.get_model(with_star=False).to(device)\nwith torch.inference_mode():\n    emission, _ = model(waveform.to(device))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_emission(emission):\n    fig, ax = plt.subplots()\n    ax.imshow(emission.cpu().T)\n    ax.set_title(\"Frame-wise class probabilities\")\n    ax.set_xlabel(\"Time\")\n    ax.set_ylabel(\"Labels\")\n    fig.tight_layout()\n\n\nplot_emission(emission[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenize the transcript\n\nWe create a dictionary that maps each label to a token.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "LABELS = bundle.get_labels(star=None)\nDICTIONARY = bundle.get_dict(star=None)\nfor k, v in DICTIONARY.items():\n    print(f\"{k}: {v}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Converting the transcript to tokens is as simple as\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]\n\nfor t in tokenized_transcript:\n    print(t, end=\" \")\nprint()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing alignments\n\n### Frame-level alignments\n\nNow we call TorchAudio\u2019s forced alignment API to compute the\nframe-level alignment. For the details of the function signature, please\nrefer to :py:func:`~torchaudio.functional.forced_align`.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def align(emission, tokens):\n    targets = torch.tensor([tokens], dtype=torch.int32, device=device)\n    alignments, scores = F.forced_align(emission, targets, blank=0)\n\n    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity\n    scores = scores.exp()  # convert back to probability\n    return alignments, scores\n\n\naligned_tokens, alignment_scores = align(emission, tokenized_transcript)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at the output.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):\n    print(f\"{i:3d}:\\t{ali:2d} [{LABELS[ali]}], {score:.2f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

<div class=\"alert alert-info\"><h4>Note</h4><p>The alignment is expressed in the frame coordinate of the emission,\n   which is different from the original waveform.</p></div>
\n\nIt contains blank tokens and repeated tokens. The following is the\ninterpretation of the non-blank tokens.\n\n```\n31: 0 [-], 1.00\n32: 2 [i], 1.00 \"i\" starts and ends\n33: 0 [-], 1.00\n34: 0 [-], 1.00\n35: 15 [h], 1.00 \"h\" starts\n36: 15 [h], 0.93 \"h\" ends\n37: 1 [a], 1.00 \"a\" starts and ends\n38: 0 [-], 0.96\n39: 0 [-], 1.00\n40: 0 [-], 1.00\n41: 13 [d], 1.00 \"d\" starts and ends\n42: 0 [-], 1.00\n```\n
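\nTo relate these frame indices to time in the original audio, you can use the ratio of\nwaveform samples to emission frames (the same conversion used by the ``preview_word`` and\n``plot_alignments`` helpers defined later in this tutorial). The following is a minimal\nsketch, not part of the original tutorial, assuming the ``waveform``, ``emission`` and\n``bundle`` variables defined above.\n\n```\n# Sketch: convert an emission frame index into seconds.\nframe_to_sec = waveform.size(1) / emission.size(1) / bundle.sample_rate\nprint(f\"frame 35 (where 'h' starts above) is at ~{35 * frame_to_sec:.3f} sec\")\n```\n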

<div class=\"alert alert-info\"><h4>Note</h4><p>When the same token occurs after blank tokens, it is not treated as\n   a repeat, but as a new occurrence.</p></div>\n\n```\na a a b -> a b\na - - b -> a b\na a - b -> a b\na - a b -> a a b\n   ^^^       ^^^
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Token-level alignments\n\nThe next step is to resolve the repetition, so that each alignment does\nnot depend on previous alignments.\n:py:func:`torchaudio.functional.merge_tokens` computes the\n:py:class:`~torchaudio.functional.TokenSpan` object, which represents\nwhich token from the transcript is present at what time span.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "token_spans = F.merge_tokens(aligned_tokens, alignment_scores)\n\nprint(\"Token\\tTime\\tScore\")\nfor s in token_spans:\n    print(f\"{LABELS[s.token]}\\t[{s.start:3d}, {s.end:3d})\\t{s.score:.2f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Word-level alignments\n\nNow let\u2019s group the token-level alignments into word-level alignments.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def unflatten(list_, lengths):\n    assert len(list_) == sum(lengths)\n    i = 0\n    ret = []\n    for l in lengths:\n        ret.append(list_[i : i + l])\n        i += l\n    return ret\n\n\nword_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Audio previews\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute average score weighted by the span length\ndef _score(spans):\n    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)\n\n\ndef preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):\n    ratio = waveform.size(1) / num_frames\n    x0 = int(ratio * spans[0].start)\n    x1 = int(ratio * spans[-1].end)\n    print(f\"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec\")\n    segment = waveform[:, x0:x1]\n    return IPython.display.Audio(segment.numpy(), rate=sample_rate)\n\n\nnum_frames = emission.size(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Generate the audio for each segment\nprint(TRANSCRIPT)\nIPython.display.Audio(SPEECH_FILE)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])" ] }, { "cell_type": "code", 
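"execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Optional sketch, not part of the original tutorial: render the per-word previews\n# in a loop instead of one cell per word. Assumes the `preview_word` helper,\n# `word_spans`, `num_frames` and `TRANSCRIPT` defined above.\nfor spans, word in zip(word_spans, TRANSCRIPT):\n    IPython.display.display(preview_word(waveform, spans, num_frames, word))" ] }, { "cell_type": "code",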
"execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualization\n\nNow let's look at the alignment result and segment the original\nspeech into words.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):\n ratio = waveform.size(1) / emission.size(1) / sample_rate\n\n fig, axes = plt.subplots(2, 1)\n axes[0].imshow(emission[0].detach().cpu().T, aspect=\"auto\")\n axes[0].set_title(\"Emission\")\n axes[0].set_xticks([])\n\n axes[1].specgram(waveform[0], Fs=sample_rate)\n for t_spans, chars in zip(token_spans, transcript):\n t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1\n axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor=\"None\", hatch=\"/\", edgecolor=\"white\")\n axes[1].axvspan(ratio * t0, ratio * t1, facecolor=\"None\", hatch=\"/\", edgecolor=\"white\")\n axes[1].annotate(f\"{_score(t_spans):.2f}\", (ratio * t0, sample_rate * 0.51), annotation_clip=False)\n\n for span, char in zip(t_spans, chars):\n t0 = span.start * ratio\n axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)\n\n axes[1].set_xlabel(\"time [second]\")\n axes[1].set_xlim([0, None])\n fig.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plot_alignments(waveform, word_spans, emission, TRANSCRIPT)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inconsistent treatment of ``blank`` token\n\nWhen splitting the token-level alignments into words, you will\nnotice that some blank tokens are treated differently, and this makes\nthe interpretation of the result somehwat ambigious.\n\nThis is easy to see when we plot the scores. 
The following figure\nshows word regions and non-word regions, with the frame-level scores\nof non-blank tokens.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def plot_scores(word_spans, scores):\n    fig, ax = plt.subplots()\n    span_xs, span_hs = [], []\n    ax.axvspan(word_spans[0][0].start - 0.05, word_spans[-1][-1].end + 0.05, facecolor=\"paleturquoise\", edgecolor=\"none\", zorder=-1)\n    for t_span in word_spans:\n        for span in t_span:\n            for t in range(span.start, span.end):\n                span_xs.append(t + 0.5)\n                span_hs.append(scores[t].item())\n            ax.annotate(LABELS[span.token], (span.start, -0.07))\n        ax.axvspan(t_span[0].start - 0.05, t_span[-1].end + 0.05, facecolor=\"mistyrose\", edgecolor=\"none\", zorder=-1)\n    ax.bar(span_xs, span_hs, color=\"lightsalmon\", edgecolor=\"coral\")\n    ax.set_title(\"Frame-level scores and word segments\")\n    ax.set_ylim(-0.1, None)\n    ax.grid(True, axis=\"y\")\n    ax.axhline(0, color=\"black\")\n    fig.tight_layout()\n\n\nplot_scores(word_spans, alignment_scores)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this plot, the blank tokens are the highlighted areas without a\nvertical bar.\nYou can see that there are blank tokens which are interpreted as\npart of a word (highlighted red), while the others (highlighted blue)\nare not.\n\nOne reason for this is that the model was trained without a\nlabel for the word boundary. The blank tokens are treated not just\nas repetition but also as silence between words.\n\nBut then, a question arises. Should frames immediately after or\nnear the end of a word be silent or repeat?\n\nIn the above example, if you go back to the previous plot of the\nspectrogram and word regions, you see that after \"y\" in \"curiosity\",\nthere is still some activity in multiple frequency buckets.\n\nWould it be more accurate if that frame was included in the word?\n\nUnfortunately, CTC does not provide a comprehensive solution to this.\nModels trained with CTC are known to exhibit a \"peaky\" response,\nthat is, they tend to spike for an occurrence of a label, but the\nspike does not last for the duration of the label.\n(Note: Pre-trained Wav2Vec2 models tend to spike at the beginning of\nlabel occurrences, but this is not always the case.)\n\n:cite:`zeyer2021does` has an in-depth analysis of the peaky behavior of\nCTC.\nWe encourage those who are interested in understanding more to refer\nto the paper.\nThe following is a quote from the paper, which describes the exact issue we\nare facing here.\n\n *Peaky behavior can be problematic in certain cases,*\n *e.g. when an application requires to not use the blank label,*\n *e.g. to get meaningful time accurate alignments of phonemes*\n *to a transcription.*\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Advanced: Handling transcripts with ``<star>`` token\n\nNow let\u2019s look at how we can improve alignment quality when the transcript is\npartially missing, using the ``<star>`` token, which is capable of modeling\nany token.\n\nHere we use the same English example as above, but we remove the\nbeginning text ``\u201ci had that curiosity beside me at\u201d`` from the transcript.\nAligning the audio with such a transcript results in wrong alignments of the\nexisting word \u201cthis\u201d. 
However, this issue can be mitigated by using the\n``<star>`` token to model the missing text.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we extend the dictionary to include the ``<star>`` token.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "DICTIONARY[\"*\"] = len(DICTIONARY)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we extend the emission tensor with the extra dimension\ncorresponding to the ``<star>`` token.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)\nemission = torch.cat((emission, star_dim), 2)\n\nassert len(DICTIONARY) == emission.shape[2]\n\nplot_emission(emission[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function combines all the processes, and computes\nword segments from the emission in one go.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def compute_alignments(emission, transcript, dictionary):\n    tokens = [dictionary[char] for word in transcript for char in word]\n    alignment, scores = align(emission, tokens)\n    token_spans = F.merge_tokens(alignment, scores)\n    word_spans = unflatten(token_spans, [len(word) for word in transcript])\n    return word_spans" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Full Transcript\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)\nplot_alignments(waveform, word_spans, emission, TRANSCRIPT)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Partial Transcript with ``<star>`` token\n\nNow we replace the first part of the transcript with the ``<star>`` token.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "transcript = \"* this moment\".split()\nword_spans = compute_alignments(emission, transcript, DICTIONARY)\nplot_alignments(waveform, word_spans, emission, transcript)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[0], num_frames, transcript[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[1], num_frames, transcript[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "preview_word(waveform, word_spans[2], num_frames, transcript[2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Partial Transcript without ``<star>`` token\n\nAs a comparison, the following aligns the partial transcript\nwithout using the ``<star>`` token.\nIt demonstrates the effect of the ``<star>`` token for dealing with deletion errors.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "transcript = \"this moment\".split()\nword_spans = compute_alignments(emission, transcript, DICTIONARY)\nplot_alignments(waveform, word_spans, emission, transcript)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n\nIn this tutorial, we looked at how to use torchaudio\u2019s forced alignment\nAPI to align and 
segment speech files, and demonstrated one advanced usage:\nhow introducing a ``<star>`` token could improve alignment accuracy when\ntranscription errors exist.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Acknowledgement\n\nThanks to [Vineel Pratap](vineelkpratap@meta.com) and [Zhaoheng\nNi](zni@meta.com) for developing and open-sourcing the\nforced aligner API.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }