
Convert and Infer Models on TensorFlowLite

Everything here can be followed along in Google Colab!


Installation and Setup

First, we need to install TensorFlowTTS from w11wo's fork of the original repository.

!git clone -q https://github.com/w11wo/TensorFlowTTS.git
!cd TensorFlowTTS && pip install -q . > /dev/null

Then, we'll need to downgrade TensorFlow to version 2.3.1. This is the version I found to work well all the way through mobile deployment, although newer versions may also work.

!pip install -q tensorflow-gpu==2.3.1

We also need to downgrade NumPy to a version compatible with this TensorFlow release.

!pip install -q numpy==1.20.3

IMPORTANT: after re-installing TensorFlow and NumPy, be sure to restart your Colab Runtime!
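To confirm that the restart took effect, you could compare the versions Python actually imports against the pins above. The helper below is a small sketch of that check; `check_versions` and `EXPECTED` are illustrative names, not part of TensorFlowTTS.

```python
import importlib
from typing import Dict, Optional, Tuple

# Versions pinned in the installation steps above.
EXPECTED = {"tensorflow": "2.3.1", "numpy": "1.20.3"}


def check_versions(
    expected: Dict[str, str] = EXPECTED,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, found)} for every package that does not match.

    `found` is None when the package cannot be imported or has no __version__.
    """
    mismatches = {}
    for name, wanted in expected.items():
        try:
            found = getattr(importlib.import_module(name), "__version__", None)
        except ImportError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches
```

After restarting, `check_versions()` should return an empty dict; any entry left in it points to a package that is still on the old version.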

Log into HuggingFace Hub

If you previously saved your model weights to the HuggingFace Hub, loading them back is straightforward. Private Hub models can also be loaded, as long as you first log in to the Hub, which we'll do via notebook_login.

from huggingface_hub import notebook_login
notebook_login()
Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.huggingface/token
Login successful

To load a private Hub model, you just have to specify use_auth_token=True later.

Convert Models

Typically, converting to a TensorFlowLite model involves these steps:

  1. Loading the model weights
  2. Getting the concrete function of the model's forward pass
  3. Setting up the converter
  4. Specifying optimizations
  5. Converting and saving the TFLite model

import tensorflow as tf
from tensorflow_tts.inference import TFAutoModel
def convert_text2mel_tflite(
    model_path: str, save_name: str, use_auth_token: bool = False
) -> float:
    """Convert text2mel model to TFLite.

    Args:
        model_path (str): Pretrained model checkpoint in HuggingFace Hub.
        save_name (str): Output TFLite filename.
        use_auth_token (bool, optional): Use HF Hub Token. Defaults to False.

    Returns:
        float: Model size in Megabytes.
    """
    # load pretrained model
    model = TFAutoModel.from_pretrained(
        model_path, enable_tflite_convertible=True, use_auth_token=use_auth_token
    )

    # setup model concrete function
    concrete_function = model.inference_tflite.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])

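    # specify optimizations: Optimize.DEFAULT applies dynamic-range (weight) quantization
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # built-in TFLite ops
        tf.lite.OpsSet.SELECT_TF_OPS,  # fall back to TF ops where needed
    ]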

    # convert and save model to TensorFlowLite
    tflite_model = converter.convert()
    with open(save_name, "wb") as f:
        f.write(tflite_model)

    size = len(tflite_model) / 1024 / 1024.0
    return size
def convert_vocoder_tflite(
    model_path: str, save_name: str, use_auth_token: bool = False
) -> float:
    """Convert vocoder model to TFLite.

    Args:
        model_path (str): Pretrained model checkpoint in HuggingFace Hub.
        save_name (str): Output TFLite filename.
        use_auth_token (bool, optional): Use HF Hub Token. Defaults to False.

    Returns:
        float: Model size in Megabytes.
    """
    # load pretrained model
    model = TFAutoModel.from_pretrained(model_path, use_auth_token=use_auth_token)

    # setup model concrete function
    concrete_function = model.inference_tflite.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])

    # specify optimizations
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.SELECT_TF_OPS]
    converter.target_spec.supported_types = [tf.float16]  # fp16 ops

    # convert and save model to TensorFlowLite
    tflite_model = converter.convert()
    with open(save_name, "wb") as f:
        f.write(tflite_model)

    size = len(tflite_model) / 1024 / 1024.0
    return size
text2mel = convert_text2mel_tflite(
    model_path="bookbot/lightspeech-mfa-id-v3",
    save_name="lightspeech_quant.tflite",
    use_auth_token=True,
)

vocoder = convert_vocoder_tflite(
    model_path="bookbot/mb-melgan-hifi-postnets-id-v10",
    save_name="mbmelgan.tflite",
    use_auth_token=True,
)
print(f"Text2mel: {text2mel} MBs\nVocoder: {vocoder} MBs")
Text2mel: 4.323883056640625 MBs
Vocoder: 5.0258941650390625 MBs
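These sizes are computed from the serialized bytes (`len(tflite_model) / 1024 / 1024`). To check the files actually written to disk, a small helper along these lines should report the same numbers; `tflite_size_mb` is a hypothetical name used only for illustration.

```python
import os


def tflite_size_mb(path: str) -> float:
    """Return a file's size in MB, computed the same way as len(bytes) / 1024 / 1024."""
    return os.path.getsize(path) / 1024 / 1024.0


# Compare against the values returned by the conversion functions.
for name in ("lightspeech_quant.tflite", "mbmelgan.tflite"):
    if os.path.exists(name):
        print(f"{name}: {tflite_size_mb(name):.2f} MBs")
```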

Conversion Script

We also provide a script version of the conversion steps above, located at TensorFlowTTS/examples/tensorflowlite/convert_tflite.py. To use it, specify the arguments on the command line, for example:

!python TensorFlowTTS/examples/tensorflowlite/convert_tflite.py \
    --text2mel_path="bookbot/lightspeech-mfa-id-v3" \
    --text2mel_savename="lightspeech_quant.tflite" \
    --vocoder_path="bookbot/mb-melgan-hifi-postnets-id-v10" \
    --vocoder_savename="mbmelgan.tflite" \
    --use_auth_token
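For reference, the flags above map naturally onto an argparse parser. The sketch below is a hypothetical reconstruction that only mirrors the flag names used in the command; it is not the contents of convert_tflite.py, whose defaults and extra options may differ.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the command-line flags shown above."""
    parser = argparse.ArgumentParser(description="Convert TensorFlowTTS models to TFLite.")
    parser.add_argument("--text2mel_path", required=True, help="HF Hub id of the text2mel model")
    parser.add_argument("--text2mel_savename", required=True, help="output .tflite file for text2mel")
    parser.add_argument("--vocoder_path", required=True, help="HF Hub id of the vocoder model")
    parser.add_argument("--vocoder_savename", required=True, help="output .tflite file for the vocoder")
    # boolean switch: present means True, absent means False
    parser.add_argument("--use_auth_token", action="store_true", help="use the saved HF Hub token")
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse argv (defaults to sys.argv[1:] when None)."""
    return build_parser().parse_args(argv)
```

In a script like this, the parsed fields would then be forwarded to functions such as convert_text2mel_tflite and convert_vocoder_tflite.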

Inference

With the converted TFLite models, we can then perform inference on the TFLite Runtime. Here, we'll only show inference for LightSpeech + Multi-band MelGAN. Other models may differ (e.g. FastSpeech2 has different model outputs compared to LightSpeech), but adapting the inference code to them should be straightforward as long as you know each model's outputs.

Tokenization

To tokenize raw text, we can simply load the processor (tokenizer) we used during training. If it's stored on the HuggingFace Hub, you can conveniently load it from there during inference.

If the processor is private, load it the same way you would load a private Hub model, by passing use_auth_token=True.

from tensorflow_tts.inference import AutoProcessor

processor = AutoProcessor.from_pretrained("bookbot/lightspeech-mfa-id-v3", use_auth_token=True)
processor.mode = "eval" # change processor from train to eval mode

With the processor, we can then tokenize any input text and convert it to its corresponding input IDs (a list of token IDs).

from typing import List, Tuple

def tokenize(text: str, processor: AutoProcessor) -> List[int]:
    """Tokenize text to input ids.

    Args:
        text (str): Input text to tokenize.
        processor (AutoProcessor): Processor for tokenization.

    Returns:
        List[int]: List of input (token) ids.
    """
    return processor.text_to_sequence(text)
text = "Halo, bagaimana kabar mu?"
input_ids = tokenize(text, processor)
input_ids
[8, 1, 12, 15, 32, 2, 1, 7, 1, 9, 13, 1, 14, 1, 11, 1, 2, 1, 17, 13, 20, 34]

Prepare LightSpeech Input

LightSpeech expects 5 inputs for inference, namely:

  1. Input IDs
  2. Speaker ID
  3. Speed Ratio
  4. Pitch Ratio
  5. Energy Ratio

Speaker ID is only relevant for a multi-speaker model, with each index (starting from 0) corresponding to different speaker embeddings for the text2mel model to use.

You can also alter other options such as speed, which serves as a duration multiplier (e.g. a speed ratio of 2 gives half the normal speed). Pitch and energy can be adjusted in a similar way. For simplicity, I'll only parameterize the speaker ID.
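To make the speed ratio concrete: assuming the model multiplies each predicted token duration (in mel frames) by the ratio and rounds the result, a ratio of 2.0 doubles the number of mel frames, and therefore the length of the audio. The snippet below is a simplified NumPy illustration of that arithmetic, not LightSpeech's actual code; the durations are made-up values.

```python
import numpy as np


def scale_durations(durations, speed_ratio: float) -> np.ndarray:
    """Scale per-token durations (in mel frames) by a speed ratio and round to ints.

    Simplified illustration: ratio > 1 lengthens (slows down) speech,
    ratio < 1 shortens (speeds up) speech.
    """
    scaled = np.round(np.asarray(durations, dtype=np.float32) * speed_ratio)
    return scaled.astype(np.int32)


durations = [3, 5, 2, 4]  # hypothetical durations for four tokens
print(scale_durations(durations, 1.0).sum())  # 14 frames at normal speed
print(scale_durations(durations, 2.0).sum())  # 28 frames: half the normal speed
```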

def prepare_input(
    input_ids: List[int], speaker: int
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
    """Prepares input for LightSpeech TFLite inference.

    Args:
        input_ids (List[int]): Phoneme input ids according to processor.
        speaker (int): Speaker ID.

    Returns:
        Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
            Tuple of tensors consisting of:
                1. Input IDs
                2. Speaker ID
                3. Speed Ratio
                4. Pitch Ratio
                5. Energy Ratio
    """
    input_ids = tf.expand_dims(tf.convert_to_tensor(input_ids, dtype=tf.int32), 0)
    return (
        input_ids,
        tf.convert_to_tensor([speaker], tf.int32),
        tf.convert_to_tensor([1.0], dtype=tf.float32),
        tf.convert_to_tensor([1.0], dtype=tf.float32),
        tf.convert_to_tensor([1.0], dtype=tf.float32),
    )
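If you do want to control speed, pitch, and energy, one option is to expose them as keyword arguments. The variant below is a sketch that builds NumPy arrays instead of tf.Tensors (the TFLite Interpreter's set_tensor accepts NumPy arrays); `prepare_input_with_ratios` is a hypothetical name, and its dtypes and shapes simply mirror prepare_input above.

```python
from typing import List, Tuple

import numpy as np


def prepare_input_with_ratios(
    input_ids: List[int],
    speaker: int,
    speed_ratio: float = 1.0,
    pitch_ratio: float = 1.0,
    energy_ratio: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Same layout as prepare_input, but with adjustable speed/pitch/energy ratios."""
    return (
        np.asarray([input_ids], dtype=np.int32),  # shape [1, len(input_ids)]
        np.asarray([speaker], dtype=np.int32),  # shape [1]
        np.asarray([speed_ratio], dtype=np.float32),  # > 1.0 slows speech down
        np.asarray([pitch_ratio], dtype=np.float32),
        np.asarray([energy_ratio], dtype=np.float32),
    )
```

Because ls_infer calls prepare_input internally, using this variant would mean swapping it in there and forwarding the extra arguments.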

TFLite Inference

To perform inference on a TensorFlowLite Runtime, the general flow is as follows:

  1. Load model weights to TFLite Interpreter
  2. Resize interpreter input tensors according to actual input
  3. Allocate tensors according to Interpreter's inputs
  4. Set the actual tensor values as inputs to the Interpreter
  5. Invoke interpreter
  6. Return output tensors
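Steps 2 to 6 are the same for any TFLite model, so they can be factored into one helper. The sketch below assumes the caller has already done step 1 (e.g. `tf.lite.Interpreter(model_path=...)`) and passes inputs in the same order as `interpreter.get_input_details()`; `run_interpreter` is a hypothetical name. It only calls the Interpreter methods that the per-model functions in this section use.

```python
from typing import Any, List, Sequence

import numpy as np


def run_interpreter(interpreter: Any, inputs: Sequence[np.ndarray]) -> List[np.ndarray]:
    """Run steps 2-6 on an already-loaded TFLite Interpreter.

    `inputs` must follow the order of interpreter.get_input_details().
    """
    input_details = interpreter.get_input_details()
    if len(inputs) != len(input_details):
        raise ValueError(f"expected {len(input_details)} inputs, got {len(inputs)}")

    # 2. resize every input tensor to the shape of the actual data
    for detail, value in zip(input_details, inputs):
        interpreter.resize_tensor_input(detail["index"], list(np.shape(value)))

    # 3. allocate memory for the (re)shaped tensors
    interpreter.allocate_tensors()

    # 4. copy the data in, cast to the dtype each input expects
    for detail, value in zip(input_details, inputs):
        interpreter.set_tensor(detail["index"], np.asarray(value, dtype=detail["dtype"]))

    # 5. run the model
    interpreter.invoke()

    # 6. read every output tensor
    return [interpreter.get_tensor(d["index"]) for d in interpreter.get_output_details()]
```

The trade-off is that this relies on input order rather than naming each input explicitly, which is the same assumption the LightSpeech function makes when it loops over `input_details`.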

We can follow the steps outlined above and create an inference function for LightSpeech and MB-MelGAN.

import numpy as np

def ls_infer(
    input_ids: List[int], speaker: int, lightspeech_path: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Performs LightSpeech inference.

    Args:
        input_ids (List[int]): Phoneme input ids according to processor.
        speaker (int): Speaker ID.
        lightspeech_path (str): Path to the LightSpeech TFLite model.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            Tuple of arrays consisting of:
                1. Mel-spectrogram output
                2. Durations array
    """
    # load model to Interpreter
    lightspeech = tf.lite.Interpreter(model_path=lightspeech_path)
    input_details = lightspeech.get_input_details()
    output_details = lightspeech.get_output_details()

    # resize input tensors according to actual shape
    lightspeech.resize_tensor_input(input_details[0]["index"], [1, len(input_ids)])
    lightspeech.resize_tensor_input(input_details[1]["index"], [1])
    lightspeech.resize_tensor_input(input_details[2]["index"], [1])
    lightspeech.resize_tensor_input(input_details[3]["index"], [1])
    lightspeech.resize_tensor_input(input_details[4]["index"], [1])

    # allocate tensors
    lightspeech.allocate_tensors()

    input_data = prepare_input(input_ids, speaker)

    # set input tensors
    for i, detail in enumerate(input_details):
        lightspeech.set_tensor(detail["index"], input_data[i])

    # invoke interpreter
    lightspeech.invoke()

    # return outputs
    return (
        lightspeech.get_tensor(output_details[0]["index"]),
        lightspeech.get_tensor(output_details[1]["index"]),
    )
def melgan_infer(melspectrogram: tf.Tensor, mb_melgan_path: str) -> tf.Tensor:
    """Performs MB-MelGAN inference.

    Args:
        melspectrogram (tf.Tensor): Mel-spectrogram to synthesize.
        mb_melgan_path (str): Path to MB-MelGAN weights.

    Returns:
        tf.Tensor: Synthesized audio output tensor.
    """
    # load model to Interpreter
    mb_melgan = tf.lite.Interpreter(model_path=mb_melgan_path)
    input_details = mb_melgan.get_input_details()
    output_details = mb_melgan.get_output_details()

    # resize input tensors according to actual shape
    mb_melgan.resize_tensor_input(
        input_details[0]["index"],
        [1, melspectrogram.shape[1], melspectrogram.shape[2]],
        strict=True,
    )

    # allocate tensors
    mb_melgan.allocate_tensors()

    # set input tensors
    mb_melgan.set_tensor(input_details[0]["index"], melspectrogram)

    # invoke interpreter
    mb_melgan.invoke()

    # return output
    return mb_melgan.get_tensor(output_details[0]["index"])

Finally, we can run end-to-end inference using the models we converted earlier.

mel_output_tflite, *_ = ls_infer(
    input_ids, speaker=0, lightspeech_path="lightspeech_quant.tflite"
)

audio_tflite = melgan_infer(mel_output_tflite, mb_melgan_path="mbmelgan.tflite")[
    0, :, 0
]

To listen to the synthesized output in Jupyter, we can pass it directly to IPython's Audio widget, specifying the audio's sample rate.

from IPython.display import Audio

Audio(data=audio_tflite, rate=32000)

Alternatively, we can write the output audio to a WAV file and play it back from disk.

import soundfile as sf

sf.write("./audio.wav", audio_tflite, 32000, "PCM_16")
from IPython.display import Audio

Audio("audio.wav")
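If soundfile isn't installed, Python's built-in wave module can write a similar 16-bit PCM file. The sketch below (`write_wav_pcm16` is a hypothetical helper) clips the float waveform to [-1, 1] and scales it to int16, which approximates what writing with the "PCM_16" subtype does; it assumes a mono waveform such as audio_tflite.

```python
import wave

import numpy as np


def write_wav_pcm16(path: str, audio, sample_rate: int = 32000) -> None:
    """Write a mono float waveform in [-1, 1] as a 16-bit PCM WAV file."""
    samples = np.clip(np.asarray(audio, dtype=np.float32), -1.0, 1.0)
    pcm = (samples * 32767).astype("<i2")  # little-endian int16, as WAV expects
    with wave.open(path, "wb") as f:
        f.setnchannels(1)  # mono
        f.setsampwidth(2)  # 2 bytes per sample = 16-bit
        f.setframerate(sample_rate)
        f.writeframes(pcm.tobytes())
```

For example, `write_wav_pcm16("./audio.wav", audio_tflite, 32000)` would replace the sf.write call above.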