{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Device AV-ASR with Emformer RNN-T\n\n**Author**: [Pingchuan Ma](pingchuanma@meta.com)_, [Moto\nHira](moto@meta.com)_.\n\nThis tutorial shows how to run on-device audio-visual speech recognition\n(AV-ASR, or AVSR) with TorchAudio on a streaming device input,\ni.e.\u00a0microphone on laptop. AV-ASR is the task of transcribing text from\naudio and visual streams, which has recently attracted a lot of research\nattention due to its robustness against noise.\n\n

Note

This tutorial requires the ffmpeg, sentencepiece, mediapipe,\nopencv-python and scikit-image libraries.\n\nThere are multiple ways to install the FFmpeg libraries.\nIf you are using the Anaconda Python distribution,\n``conda install -c conda-forge 'ffmpeg<7'`` will install\ncompatible FFmpeg libraries.\n\nYou can run\n``pip install sentencepiece mediapipe opencv-python scikit-image``\nto install the other libraries mentioned.

\n\n

Note

To run this tutorial, please make sure you are in the `tutorial` folder.

\n\n

Note

We tested this tutorial with torchaudio version 2.0.2 on a MacBook Pro (M1 Pro).

\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport sentencepiece as spm\nimport torch\nimport torchaudio\nimport torchvision" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n\nThe real-time AV-ASR system is presented as follows, which consists of\nthree components, a data collection module, a pre-processing module and\nan end-to-end model. The data collection module is hardware, such as a\nmicrophone and camera. Its role is to collect information from the real\nworld. Once the information is collected, the pre-processing module\nlocation and crop out face. Next, we feed the raw audio stream and the\npre-processed video stream into our end-to-end model for inference.\n\n\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Data acquisition\n\nFirstly, we define the function to collect videos from microphone and\ncamera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`\nclass for the purpose of data collection, which supports capturing\naudio/video from microphone and camera. For the detailed usage of this\nclass, please refer to the\n[tutorial](./streamreader_basic_tutorial.html)_.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def stream(q, format, option, src, segment_length, sample_rate):\n print(\"Building StreamReader...\")\n streamer = torchaudio.io.StreamReader(src=src, format=format, option=option)\n streamer.add_basic_video_stream(frames_per_chunk=segment_length, buffer_chunk_size=500, width=600, height=340)\n streamer.add_basic_audio_stream(frames_per_chunk=segment_length * 640, sample_rate=sample_rate)\n\n print(streamer.get_src_stream_info(0))\n print(streamer.get_src_stream_info(1))\n print(\"Streaming...\")\n print()\n for (chunk_v, chunk_a) in streamer.stream(timeout=-1, backoff=1.0):\n q.put([chunk_v, chunk_a])\n\n\nclass ContextCacher:\n def __init__(self, segment_length: int, context_length: int, rate_ratio: int):\n self.segment_length = segment_length\n self.context_length = context_length\n\n self.context_length_v = context_length\n self.context_length_a = context_length * rate_ratio\n self.context_v = torch.zeros([self.context_length_v, 3, 340, 600])\n self.context_a = torch.zeros([self.context_length_a, 1])\n\n def __call__(self, chunk_v, chunk_a):\n if chunk_v.size(0) < self.segment_length:\n chunk_v = torch.nn.functional.pad(chunk_v, (0, 0, 0, 0, 0, 0, 0, self.segment_length - chunk_v.size(0)))\n if chunk_a.size(0) < self.segment_length * 640:\n chunk_a = torch.nn.functional.pad(chunk_a, (0, 0, 0, self.segment_length * 640 - chunk_a.size(0)))\n\n if self.context_length == 0:\n return chunk_v.float(), chunk_a.float()\n else:\n chunk_with_context_v = torch.cat((self.context_v, chunk_v))\n chunk_with_context_a = torch.cat((self.context_a, chunk_a))\n self.context_v = chunk_v[-self.context_length_v :]\n self.context_a = chunk_a[-self.context_length_a :]\n return chunk_with_context_v.float(), chunk_with_context_a.float()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Pre-processing\n\nBefore feeding the raw stream into our model, each video sequence has to\nundergo a specific pre-processing procedure. This involves three\ncritical steps. The first step is to perform face detection. 
Following\nthat, each individual frame is aligned to a reference frame, commonly\nknown as the mean face, in order to normalize rotation and size\ndifferences across frames. The final step in the pre-processing module\nis to crop the face region from the aligned face image.\n\n| ![Original](https://download.pytorch.org/torchaudio/doc-assets/avsr/original.gif) | ![Detected](https://download.pytorch.org/torchaudio/doc-assets/avsr/detected.gif) | ![Transformed](https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif) | ![Cropped](https://download.pytorch.org/torchaudio/doc-assets/avsr/cropped.gif) |\n|:---:|:---:|:---:|:---:|\n| 0. Original | 1. Detected | 2. Transformed | 3. Cropped |\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import sys\n\nsys.path.insert(0, \"../../examples\")\n\nfrom avsr.data_prep.detectors.mediapipe.detector import LandmarksDetector\nfrom avsr.data_prep.detectors.mediapipe.video_process import VideoProcess\n\n\nclass FunctionalModule(torch.nn.Module):\n    \"\"\"Wraps a plain callable so it can be used inside ``torch.nn.Sequential``.\"\"\"\n\n    def __init__(self, functional):\n        super().__init__()\n        self.functional = functional\n\n    def forward(self, input):\n        return self.functional(input)\n\n\nclass Preprocessing(torch.nn.Module):\n    \"\"\"Detects landmarks, aligns and crops the face, then normalizes the video.\"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.landmarks_detector = LandmarksDetector()\n        self.video_process = VideoProcess()\n        self.video_transform = torch.nn.Sequential(\n            FunctionalModule(\n                lambda n: [(lambda x: torchvision.transforms.functional.resize(x, 44, antialias=True))(i) for i in n]\n            ),\n            FunctionalModule(lambda x: torch.stack(x)),\n            torchvision.transforms.Normalize(0.0, 255.0),\n            torchvision.transforms.Grayscale(),\n            torchvision.transforms.Normalize(0.421, 0.165),\n        )\n\n    def forward(self, audio, video):\n        # (T, C, H, W) tensor -> (T, H, W, C) uint8 numpy array for the detector.\n        video = video.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)\n        landmarks = self.landmarks_detector(video)\n        video = self.video_process(video, landmarks)\n        video = torch.tensor(video).permute(0, 3, 1, 2).float()\n        video = self.video_transform(video)\n        # Downmix the audio to mono.\n        audio = audio.mean(axis=-1, keepdim=True)\n        return audio, video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Building the inference pipeline\n\nThe next step is to create the components required for the pipeline.\n\nWe use convolution-based front-ends to extract features from both the\nraw audio and video streams. These features are then passed through a\ntwo-layer MLP for fusion. For our transducer model, we leverage the\nTorchAudio library, which incorporates an encoder (Emformer), a\npredictor, and a joint network.
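To make the fusion step concrete, the sketch below shows the general\nshape of such a two-layer fusion MLP; the feature dimensions are\nillustrative assumptions, not the values used by the pre-trained model:\n\n```python\nimport torch\n\n\nclass FusionMLP(torch.nn.Module):\n    \"\"\"Illustrative two-layer MLP fusing time-aligned audio/video features.\"\"\"\n\n    def __init__(self, audio_dim=512, video_dim=512, hidden_dim=1024, out_dim=512):\n        super().__init__()\n        self.mlp = torch.nn.Sequential(\n            torch.nn.Linear(audio_dim + video_dim, hidden_dim),\n            torch.nn.ReLU(),\n            torch.nn.Linear(hidden_dim, out_dim),\n        )\n\n    def forward(self, audio_feats, video_feats):\n        # Both inputs are (batch, time, feature) and time-aligned:\n        # concatenate along the feature axis, then fuse.\n        return self.mlp(torch.cat([audio_feats, video_feats], dim=-1))\n\n\nfused = FusionMLP()(torch.randn(1, 32, 512), torch.randn(1, 32, 512))\nprint(fused.shape)  # torch.Size([1, 32, 512])\n```\n\n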
The architecture of the proposed AV-ASR\nmodel is illustrated as follows.\n\n\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class SentencePieceTokenProcessor:\n    def __init__(self, sp_model):\n        self.sp_model = sp_model\n        self.post_process_remove_list = {\n            self.sp_model.unk_id(),\n            self.sp_model.eos_id(),\n            self.sp_model.pad_id(),\n        }\n\n    def __call__(self, tokens, lstrip: bool = True) -> str:\n        # Drop the leading blank token and any special tokens, then detokenize.\n        filtered_hypo_tokens = [\n            token_index for token_index in tokens[1:] if token_index not in self.post_process_remove_list\n        ]\n        # \"\\u2581\" is the SentencePiece meta symbol marking word boundaries.\n        output_string = \"\".join(self.sp_model.id_to_piece(filtered_hypo_tokens)).replace(\"\\u2581\", \" \")\n\n        if lstrip:\n            return output_string.lstrip()\n        else:\n            return output_string\n\n\nclass InferencePipeline(torch.nn.Module):\n    def __init__(self, preprocessor, model, decoder, token_processor):\n        super().__init__()\n        self.preprocessor = preprocessor\n        self.model = model\n        self.decoder = decoder\n        self.token_processor = token_processor\n\n        self.state = None\n        self.hypotheses = None\n\n    def forward(self, audio, video):\n        audio, video = self.preprocessor(audio, video)\n        feats = self.model(audio.unsqueeze(0), video.unsqueeze(0))\n        length = torch.tensor([feats.size(1)], device=audio.device)\n        # Beam search with a beam width of 10, carrying state across calls.\n        self.hypotheses, self.state = self.decoder.infer(feats, length, 10, state=self.state, hypothesis=self.hypotheses)\n        transcript = self.token_processor(self.hypotheses[0][0], lstrip=False)\n        return transcript\n\n\ndef _get_inference_pipeline(model_path, spm_model_path):\n    model = torch.jit.load(model_path)\n    model.eval()\n\n    sp_model = spm.SentencePieceProcessor(model_file=spm_model_path)\n    token_processor = SentencePieceTokenProcessor(sp_model)\n\n    decoder = torchaudio.models.RNNTBeamSearch(model.model, sp_model.get_piece_size())\n\n    return InferencePipeline(\n        preprocessor=Preprocessing(),\n        model=model,\n        decoder=decoder,\n        token_processor=token_processor,\n    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. The main process\n\nThe execution flow of the main process is as follows:\n\n1. Initialize the inference pipeline.\n2. Launch the data acquisition subprocess.\n3. Run inference.\n4.
Clean up\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchaudio.utils import download_asset\n\n\ndef main(device, src, option=None):\n print(\"Building pipeline...\")\n model_path = download_asset(\"tutorial-assets/device_avsr_model.pt\")\n spm_model_path = download_asset(\"tutorial-assets/spm_unigram_1023.model\")\n\n pipeline = _get_inference_pipeline(model_path, spm_model_path)\n\n BUFFER_SIZE = 32\n segment_length = 8\n context_length = 4\n sample_rate = 19200\n frame_rate = 30\n rate_ratio = sample_rate // frame_rate\n cacher = ContextCacher(BUFFER_SIZE, context_length, rate_ratio)\n\n import torch.multiprocessing as mp\n\n ctx = mp.get_context(\"spawn\")\n\n @torch.inference_mode()\n def infer():\n num_video_frames = 0\n video_chunks = []\n audio_chunks = []\n while True:\n chunk_v, chunk_a = q.get()\n num_video_frames += chunk_a.size(0) // 640\n video_chunks.append(chunk_v)\n audio_chunks.append(chunk_a)\n if num_video_frames < BUFFER_SIZE:\n continue\n video = torch.cat(video_chunks)\n audio = torch.cat(audio_chunks)\n video, audio = cacher(video, audio)\n pipeline.state, pipeline.hypotheses = None, None\n transcript = pipeline(audio, video.float())\n print(transcript, end=\"\", flush=True)\n num_video_frames = 0\n video_chunks = []\n audio_chunks = []\n\n q = ctx.Queue()\n p = ctx.Process(target=stream, args=(q, device, option, src, segment_length, sample_rate))\n p.start()\n infer()\n p.join()\n\n\nif __name__ == \"__main__\":\n main(\n device=\"avfoundation\",\n src=\"0:1\",\n option={\"framerate\": \"30\", \"pixel_format\": \"rgb24\"},\n )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ".. code::\n\n Building pipeline...\n Building StreamReader...\n SourceVideoStream(media_type='video', codec='rawvideo', codec_long_name='raw video', format='uyvy422', bit_rate=0, num_frames=0, bits_per_sample=0, metadata={}, width=1552, height=1552, frame_rate=1000000.0)\n SourceAudioStream(media_type='audio', codec='pcm_f32le', codec_long_name='PCM 32-bit floating point little-endian', format='flt', bit_rate=1536000, num_frames=0, bits_per_sample=0, metadata={}, sample_rate=48000.0, num_channels=1)\n Streaming...\n\n hello world\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Tag: :obj:`torchaudio.io`\n\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }