
Testing Tacotron2 Speaker Embedding Training Pipeline in Coqui-AI TTS

This test suite validates training and inference for Tacotron2 with speaker embeddings in the Coqui-AI TTS framework. It covers model configuration, training initialization, checkpoint management, and speaker-conditioned inference.

Test Coverage Overview

The suite exercises Tacotron2's speaker embedding support end to end, from configuration through a short training run to checkpoint recovery and inference.

Key areas tested include:
  • Model configuration with speaker embedding parameters (see the sketch after this list)
  • Training initialization and execution
  • Checkpoint saving and loading
  • Inference with speaker ID specification
  • Config integrity verification
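
The first item boils down to two switches on Tacotron2Config. A minimal slice of the configuration used in the test below, with every other field left at its default:

    from TTS.tts.configs.tacotron2_config import Tacotron2Config

    # Only the speaker-embedding switches; everything else keeps its default.
    config = Tacotron2Config(
        use_speaker_embedding=True,  # learn a speaker embedding table
        num_speakers=4,              # table size, matching the test dataset
    )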

Implementation Analysis

The testing approach implements an end-to-end validation of the Tacotron2 training pipeline with speaker embeddings. It drives the pipeline through CLI commands to mirror real-world usage, configuring the model with 4 speakers, silence trimming, and a tight decoder-step limit to keep the run fast. The test then verifies both training continuation and inference using the saved checkpoints.
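
Concretely, the saved JSON config serves as the base and --coqpit.* flags override nested fields, with list entries addressed by index. A minimal sketch of the pattern, assuming run_cli behaves like the repo's test helper (run the command in a shell, assert a zero exit status):

    import os

    def run_cli(command: str) -> None:
        # Assumed behavior of the tests helper: execute in a shell and
        # fail the test on a non-zero exit status.
        exit_status = os.system(command)
        assert exit_status == 0, f"command failed: {command}"

    run_cli(
        "python TTS/bin/train_tts.py --config_path test_model_config.json "
        "--coqpit.output_path train_outputs "            # overrides config.output_path
        "--coqpit.datasets.0.path tests/data/ljspeech "  # overrides the first dataset's path
    )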

Technical Details

Key technical components include:
  • CUDA device management for GPU testing (see the sketch after this list)
  • Tacotron2Config configuration with speaker embedding parameters
  • LJSpeech dataset integration
  • Checkpoint management utilities
  • CLI-based training and inference commands
  • JSON configuration validation
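
GPU selection in the first bullet happens per command rather than in Python code. A small sketch with a stand-in for the get_device_id tests helper, whose exact logic is an assumption here:

    import os

    def get_device_id() -> str:
        # Stand-in for tests.get_device_id(); assumed to return a GPU index
        # such as "0", or "" to hide all GPUs and fall back to CPU.
        return os.environ.get("TEST_DEVICE_ID", "0")

    # The env prefix restricts which devices the child process can see.
    command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path config.json"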

Best Practices Demonstrated

The test validates both training and inference workflows in a single case. It verifies the integrity of the saved configuration, restores the most recent checkpoint for both inference and continued training, and removes the run directory once it finishes, confirming that speaker embedding support works across the entire pipeline.
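
One caveat worth noting: the cleanup at the end of the test (shutil.rmtree(continue_path)) only runs if every step before it succeeds. A variant that also cleans up on failure, sketched here rather than taken from the test:

    import shutil

    continue_path = "<newest run folder>"  # as resolved via glob in the test

    try:
        ...  # inference, config assertions, continued training
    finally:
        # Remove the run folder even if an earlier step raised or failed.
        shutil.rmtree(continue_path, ignore_errors=True)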

coqui-ai/tts

tests/tts_tests/test_tacotron2_speaker_emb_train.py

import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.tacotron2_config import Tacotron2Config

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

config = Tacotron2Config(
    r=5,
    batch_size=8,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=False,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(get_tests_output_path(), "train_outputs/phoneme_cache/"),
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    use_speaker_embedding=True,
    num_speakers=4,
    max_decoder_steps=50,
)

config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.test_delay_epochs 0 "
)
run_cli(command_train)

# Find the most recent training run folder by modification time
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
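# "ljspeech-1" must match a speaker name written to speakers.json during
# training; the ljspeech_test formatter is assumed to register speakers
# named ljspeech-0 through ljspeech-3.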
speaker_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")

# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# Load the model and run inference
inference_command = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' "
    f"--speaker_idx {speaker_id} "
    f"--speakers_file_path {continue_speakers_path} "
    f"--config_path {continue_config_path} "
    f"--model_path {continue_restore_path} "
    f"--out_path {out_wav_path}"
)
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)