
Testing DelightfulTTS Acoustic Model and HiFi Decoder Implementation in Coqui-AI TTS

This test suite validates the DelightfulTTS acoustic model and the HiFi-GAN decoder (HifiganGenerator) in the Coqui-AI TTS system. It focuses on the neural network architectures and signal-processing steps essential for text-to-speech synthesis.

Test Coverage Overview

The test suite covers two core TTS components, the acoustic model and the HiFi-GAN waveform decoder:
  • Acoustic model forward pass driven by dummy tensors for tokens, spectrograms, pitch, and energy
  • HiFi-GAN decoder validation for waveform generation from spectrogram segments
  • Shape verification of output tensors against expected dimensions
  • Integration between the tokenizer and the acoustic model through a shared configuration

Implementation Analysis

The testing approach relies on PyTorch tensor operations and explicit device management for GPU/CPU compatibility. Mock inputs are generated with torch.rand() to simulate text, spectrogram, pitch, and energy features, and each model's forward pass is validated by checking output tensor shapes.
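
Reduced to its essentials, the pattern looks like this (the tensor shapes mirror the test file below; the snippet is illustrative, not part of the suite):

import torch

# Device-agnostic setup: the same test runs unchanged on CPU or GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Mock inputs via torch.rand(); casting [0, 1) noise to .long() yields
# all-zero integer IDs, which is enough to exercise embedding lookups.
tokens = torch.rand((1, 41)).long().to(device)
mels = torch.rand((1, 100, 207)).to(device)  # (batch, n_mels, frames)

# Shape verification is the suite's core assertion style.
assert tokens.shape == (1, 41) and mels.shape == (1, 100, 207)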

The tests use configuration objects (DelightfulTTSConfig, VocoderConfig) to keep model parameters consistent across test cases.
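
This is visible in the module-level setup of the test file (excerpted from the full listing below):

from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer

args = DelightfulTtsArgs()  # shared acoustic-model arguments
v_args = VocoderConfig()    # shared vocoder arguments
config = DelightfulTTSConfig(
    model_args=args,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
)
tokenizer, config = TTSTokenizer.init_from_config(config)  # tokenizer built from the same config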

Technical Details

  • Testing Framework: plain test functions with assert statements (pytest-compatible; no unittest.TestCase)
  • Core Dependencies: PyTorch, TTS library
  • Test Environment: Configurable for CPU/CUDA execution
  • Key Components: TTSTokenizer, AcousticModel, HifiganGenerator (imports shown below)
  • Data Structures: PyTorch tensors with specific shapes for audio processing
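
These components map directly onto the module's imports:

from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.vocoder.models.hifigan_generator import HifiganGenerator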

Best Practices Demonstrated

The test suite demonstrates sound unit-testing practice: each model component is instantiated, run on synthetic inputs, and checked in isolation.

Notable practices include:
  • Isolated component testing (sketched after this list)
  • Output shape validation against known expected dimensions
  • Hardware-agnostic implementation (CPU or CUDA)
  • Configurable model parameters via shared config objects
  • Clear test case organization
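
A minimal, self-contained sketch of the isolation pattern, using a hypothetical nn.Linear stand-in rather than the real AcousticModel or HifiganGenerator:

import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in layer: build one component, feed it synthetic input,
# and assert only the output shape, the same discipline the real tests apply.
layer = nn.Linear(100, 80).to(device)
out = layer(torch.rand((1, 207, 100)).to(device))  # nn.Linear maps the last dim
assert list(out.shape) == [1, 207, 80]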

coqui-ai/tts

tests/tts_tests2/test_delightful_tts_layers.py

import torch

from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
from TTS.tts.utils.helpers import rand_segments
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # run on GPU when available, otherwise CPU

args = DelightfulTtsArgs()
v_args = VocoderConfig()


config = DelightfulTTSConfig(
    model_args=args,
    # compute_f0=True,
    # f0_cache_path=os.path.join(output_path, "f0_cache"),
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    # phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
)

tokenizer, config = TTSTokenizer.init_from_config(config)


def test_acoustic_model():
    # torch.rand() samples from [0, 1), so .long() truncates every value to 0:
    # the token batch is all zeros, which is still a valid index sequence.
    dummy_tokens = torch.rand((1, 41)).long().to(device)
    dummy_text_lens = torch.tensor([41]).long().to(device)
    dummy_spec = torch.rand((1, 100, 207)).to(device)  # (batch, n_mels, frames)
    dummy_spec_lens = torch.tensor([207]).to(device)
    dummy_pitch = torch.rand((1, 1, 207)).long().to(device)  # per-frame pitch targets
    dummy_energy = torch.rand((1, 1, 207)).long().to(device)  # per-frame energy targets

    # Match the model's output channels to the 100-bin dummy spectrogram.
    args.out_channels = 100
    args.num_mels = 100

    acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device)
    acoustic_model = acoustic_model.train()

    output = acoustic_model(
        tokens=dummy_tokens,
        src_lens=dummy_text_lens,
        mel_lens=dummy_spec_lens,
        mels=dummy_spec,
        pitches=dummy_pitch,
        energies=dummy_energy,
        attn_priors=None,
        d_vectors=None,
        speaker_idx=None,
    )
    assert list(output["model_outputs"].shape) == [1, 207, 100]  # (batch, frames, n_mels)
    # output["model_outputs"].sum().backward()


def test_hifi_decoder():
    dummy_input = torch.rand((1, 207, 100)).to(device)  # (batch, frames, channels)
    dummy_spec_lens = torch.tensor([207]).to(device)

    waveform_decoder = HifiganGenerator(
        100,  # in_channels: number of mel bins
        1,  # out_channels: mono waveform
        v_args.resblock_type_decoder,
        v_args.resblock_dilation_sizes_decoder,
        v_args.resblock_kernel_sizes_decoder,
        v_args.upsample_kernel_sizes_decoder,
        v_args.upsample_initial_channel_decoder,
        v_args.upsample_rates_decoder,
        inference_padding=0,
        cond_channels=0,
        conv_pre_weight_norm=False,
        conv_post_weight_norm=False,
        conv_post_bias=False,
    ).to(device)
    waveform_decoder = waveform_decoder.train()

    # Randomly slice 32-frame segments from the spectrogram, as during vocoder
    # training; slice_ids is intentionally unused here.
    vocoder_input_slices, slice_ids = rand_segments(  # pylint: disable=unused-variable
        x=dummy_input.transpose(1, 2),
        x_lengths=dummy_spec_lens,
        segment_size=32,
        let_short_samples=True,
        pad_short=True,
    )

    outputs = waveform_decoder(x=vocoder_input_slices.detach())
    # 32 input frames × a total upsampling factor of 256 = 8192 waveform samples
    assert list(outputs.shape) == [1, 1, 8192]
    # outputs.sum().backward()