Validating GlowTTS Training Pipeline in Coqui-AI TTS

This test suite validates the training and inference functionality of the GlowTTS model within the Coqui-AI TTS framework. It covers model configuration, training initialization, checkpoint management, and inference capabilities.

Test Coverage Overview

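The test exercises the GlowTTS training pipeline end to end: it trains for one epoch, checks the saved configuration, runs inference from the latest checkpoint, and then resumes training from the same run folder.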

Key areas tested include:
  • Model configuration setup and validation
  • Training initialization and execution
  • Checkpoint management and restoration
  • Inference pipeline verification
  • Configuration persistence and loading

Implementation Analysis

The test runs a complete training-inference cycle for the GlowTTS model in a fixed order: configuration setup, model training, configuration check, inference, and resumed training. Training and inference are driven through CLI commands (TTS/bin/train_tts.py and tts), which mirrors how the tools are used in practice.

Technical patterns include:
  • CUDA device selection through CUDA_VISIBLE_DEVICES, set from get_device_id()
  • File path handling and checkpoint management
  • Configuration persistence and validation
  • Training continuation from checkpoints
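
The checkpoint-handling pattern depends on picking the newest run folder under the output directory. Below is a minimal sketch of that selection, run on a throwaway directory tree. The helper name `latest_run_dir` and the folder names are illustrative assumptions and do not appear in the test.

```python
import glob
import os
import tempfile


def latest_run_dir(output_path: str) -> str:
    """Return the most recently modified sub-directory of output_path."""
    run_dirs = glob.glob(os.path.join(output_path, "*/"))
    if not run_dirs:
        raise FileNotFoundError(f"no run folders under {output_path}")
    return max(run_dirs, key=os.path.getmtime)


# Demonstration on a temporary tree (folder names are made up).
with tempfile.TemporaryDirectory() as root:
    for offset, name in enumerate(["run-old", "run-new"]):
        path = os.path.join(root, name)
        os.mkdir(path)
        # Set explicit modification times so the ordering is deterministic.
        os.utime(path, (1_000_000 + offset, 1_000_000 + offset))
    newest = os.path.basename(os.path.normpath(latest_run_dir(root)))

print(newest)  # prints run-new
```

Raising on an empty directory gives a clearer failure than the `ValueError` that `max` would produce if training wrote no run folder.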

Technical Details

Testing tools and configuration:
  • LJSpeech dataset formatter
  • CUDA device management
  • JSON configuration handling
  • Phoneme cache implementation
  • Attention mask integration
  • Small batch sizes for training (2) and evaluation (8) to keep the run short
  • Audio preprocessing settings
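
The training command overrides configuration fields with dotted `--coqpit.` flags such as `--coqpit.datasets.0.formatter ljspeech`. The sketch below illustrates what such a dotted key presumably addresses in a nested configuration. `apply_override` is a hypothetical helper written for this illustration; it is not Coqpit's implementation, and the dictionary is a reduced stand-in for the real config.

```python
from typing import Any


def apply_override(config: dict, dotted_key: str, value: Any) -> None:
    """Set a nested entry addressed by a key like 'datasets.0.path'."""
    parts = dotted_key.split(".")
    node: Any = config
    for part in parts[:-1]:
        # Numeric parts index into lists, everything else into dicts.
        node = node[int(part)] if isinstance(node, list) else node[part]
    last = parts[-1]
    if isinstance(node, list):
        node[int(last)] = value
    else:
        node[last] = value


# Reduced stand-in for the saved configuration (assumed shape).
config = {"test_delay_epochs": -1, "datasets": [{"formatter": "", "path": ""}]}
apply_override(config, "test_delay_epochs", 0)
apply_override(config, "datasets.0.formatter", "ljspeech")
apply_override(config, "datasets.0.path", "tests/data/ljspeech")
print(config)
```

This matches the behavior the test checks later: the config is created with `test_delay_epochs=-1`, the CLI flag sets it to 0, and the assertion on the saved `config.json` expects 0.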

Best Practices Demonstrated

The test implementation showcases several testing best practices for machine learning models.

Notable practices include:
  • Systematic configuration validation
  • Resource cleanup after test execution
  • Checkpoint integrity verification
  • End-to-end pipeline validation
  • Configuration persistence checking
  • Linear script structure with each stage marked by a comment
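
The test removes the run folder with `shutil.rmtree` as its last statement, so the folder would remain on disk if an earlier step raised. One way to make cleanup unconditional is a `try`/`finally` wrapper. This is a sketch of an alternative, not what the test does; `run_with_cleanup` and `failing_step` are hypothetical names.

```python
import os
import shutil
import tempfile
from typing import Callable


def run_with_cleanup(step: Callable[[str], None]) -> None:
    """Run step(run_dir) and delete run_dir afterwards, even on failure."""
    run_dir = tempfile.mkdtemp(prefix="train_outputs_")
    try:
        step(run_dir)
    finally:
        shutil.rmtree(run_dir, ignore_errors=True)


created = []  # records the folder path so it can be checked afterwards


def failing_step(run_dir: str) -> None:
    """Write a placeholder file, then fail like a crashed training run."""
    created.append(run_dir)
    with open(os.path.join(run_dir, "checkpoint.pth"), "w", encoding="utf-8") as f:
        f.write("placeholder")
    raise RuntimeError("simulated training failure")


try:
    run_with_cleanup(failing_step)
except RuntimeError as err:
    print(f"step failed: {err}")

print(os.path.exists(created[0]))  # prints False
```

The exception still propagates to the caller, so a failing step is reported as a failure while the folder is removed regardless.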

coqui-ai/tts

tests/tts_tests2/test_glow_tts_train.py

import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.tts.configs.glow_tts_config import GlowTTSConfig

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")


config = GlowTTSConfig(
    batch_size=2,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    data_dep_init_steps=1.0,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find the most recently modified run folder under output_path
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Paths for inference with the tts CLI
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")

# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# Load the model and run inference
inference_command = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' "
    f"--config_path {continue_config_path} "
    f"--model_path {continue_restore_path} "
    f"--out_path {out_wav_path}"
)
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)