
Testing VITS Model Implementation in Coqui-TTS

This test suite validates the VITS (Conditional Variational Autoencoder with Adversarial Learning) model implementation in the Coqui TTS framework. It covers core functionality, including audio processing, model initialization, training, inference, and voice conversion.

Test Coverage Overview

The test suite provides comprehensive coverage of VITS model functionality including:

  • Audio processing utilities (loading, spectrogram conversion, mel scaling; see the round-trip sketch after this list)
  • Model initialization for single/multi-speaker and multilingual scenarios
  • Forward pass and inference workflows
  • Voice conversion capabilities
  • Training and evaluation procedures
  • Integration with speaker encoders and language embeddings
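
For example, the audio utilities are exercised as a round trip from waveform to linear spectrogram to mel spectrogram. The sketch below mirrors the calls and parameter values in test_load_audio further down; the WAV path is a placeholder for any mono audio file.

    from TTS.tts.models.vits import load_audio, wav_to_spec, spec_to_mel, amp_to_db, db_to_amp

    # Load a mono waveform; shape is (1, num_samples) at the file's native sample rate.
    wav, sr = load_audio("example_1.wav")

    # Linear spectrogram -> mel spectrogram; the number of frames must match.
    spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
    mel = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)
    assert spec.shape[2] == mel.shape[2]

    # amp_to_db / db_to_amp should round-trip to within numerical tolerance.
    assert (spec - db_to_amp(amp_to_db(spec))).abs().max() < 1e-4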

Implementation Analysis

The testing approach uses the unittest framework with systematic validation of:

  • Model architecture components and configurations
  • Input/output tensor shapes and properties
  • Training workflows with optimizer and loss computation (condensed sketch after this list)
  • Parameter updates during training
  • Checkpoint loading/saving
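
The training-workflow checks follow VITS's two-optimizer pattern: one train_step index per optimizer/criterion pair. Below is a condensed sketch of test_train_step, with a random batch in the same layout as the _create_batch helper in the file; it is illustrative rather than a drop-in replacement for the test.

    import torch
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits, VitsArgs

    config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
    model = Vits(config)
    model.train()

    # VITS trains with two optimizers (one per train_step index) and two criteria.
    optimizers = model.get_optimizer()
    criterions = model.get_criterion()

    # Random batch in the layout expected by train_step (see _create_batch below).
    batch_size, num_frames = 2, 30
    spec_lens = torch.randint(20, num_frames, (batch_size,)).long()
    spec_lens[-1] = num_frames
    batch = {
        "tokens": torch.randint(0, 24, (batch_size, 128)).long(),
        "token_lens": torch.full((batch_size,), 128, dtype=torch.long),
        "spec": torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, num_frames),
        "spec_lens": spec_lens,
        "mel": torch.rand(batch_size, config.audio["num_mels"], num_frames),
        "mel_lens": spec_lens,
        "waveform": torch.rand(batch_size, 1, config.audio["sample_rate"] * 10),
        "d_vectors": None,
        "speaker_ids": None,
        "language_ids": None,
    }

    # Discriminator and generator steps, each with its own optimizer.
    for idx in (0, 1):
        _, loss_dict = model.train_step(batch, criterions, idx)
        loss_dict["loss"].backward()
        optimizers[idx].step()
        optimizers[idx].zero_grad()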

Technical Details

  • Testing Framework: Python unittest
  • Key Dependencies: PyTorch, TensorBoard
  • Test Data: Sample audio files, speaker embeddings, language IDs
  • Device Support: CPU and CUDA
  • Configuration: Model params, audio settings, training args
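
The configuration surface is the VitsConfig / VitsArgs / VitsAudioConfig trio used throughout the file. As one concrete example, the decoder-upsampling test builds its model with settings like these (values copied from test_train_step_upsampling below, shown only as an illustration):

    import torch
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    audio_config = VitsAudioConfig(sample_rate=22050)
    model_args = VitsArgs(
        num_chars=32,
        spec_segment_size=10,
        encoder_sample_rate=11025,            # encoder runs at half the output sample rate
        interpolate_z=False,                  # upsample with decoder layers instead of interpolating z
        upsample_rates_decoder=[8, 8, 4, 2],
    )
    config = VitsConfig(model_args=model_args, audio=audio_config)
    model = Vits(config).to(device)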

Best Practices Demonstrated

The test suite showcases quality testing practices, including:

  • Comprehensive assertion checks for model behavior
  • Isolation of components for targeted testing
  • Validation of both training and inference paths
  • Edge case handling for different model configurations
  • Integration testing with external components
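
The assertion-heavy style shows up, for example, in the parameter-update check: a reference copy of the model is snapshotted before training, and every parameter is asserted to have changed after a few optimizer steps. A minimal sketch follows; the helper name here is hypothetical, while the file's own version is the _check_parameter_changes staticmethod.

    import torch

    def assert_all_parameters_updated(model: torch.nn.Module, model_ref: torch.nn.Module) -> None:
        # Every parameter should differ from its pre-training snapshot.
        for (name, param), (_, param_ref) in zip(model.named_parameters(), model_ref.named_parameters()):
            assert (param != param_ref).any(), f"param {name} with shape {param.shape} not updated"

    # Typical use, as in test_train_step below:
    #   model_ref = Vits(config)
    #   model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
    #   ... run a few train_step / optimizer steps on `model` ...
    #   assert_all_parameters_updated(model, model_ref)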

coqui-ai/tts

tests/tts_tests/test_vits.py

import copy
import os
import unittest

import torch
from trainer.logging.tensorboard_logger import TensorboardLogger

from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path
from TTS.config import load_config
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import (
    Vits,
    VitsArgs,
    VitsAudioConfig,
    amp_to_db,
    db_to_amp,
    load_audio,
    spec_to_mel,
    wav_to_mel,
    wav_to_spec,
)
from TTS.tts.utils.speakers import SpeakerManager

LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json")
SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json")
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# pylint: disable=no-self-use
class TestVits(unittest.TestCase):
    def test_load_audio(self):
        wav, sr = load_audio(WAV_FILE)
        self.assertEqual(wav.shape, (1, 41885))
        self.assertEqual(sr, 22050)

        spec = wav_to_spec(wav, n_fft=1024, hop_length=512, win_length=1024, center=False)
        mel = wav_to_mel(
            wav,
            n_fft=1024,
            num_mels=80,
            sample_rate=sr,
            hop_length=512,
            win_length=1024,
            fmin=0,
            fmax=8000,
            center=False,
        )
        mel2 = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=sr, fmin=0, fmax=8000)

        self.assertEqual((mel - mel2).abs().max(), 0)
        self.assertEqual(spec.shape[0], mel.shape[0])
        self.assertEqual(spec.shape[2], mel.shape[2])

        spec_db = amp_to_db(spec)
        spec_amp = db_to_amp(spec_db)

        self.assertAlmostEqual((spec - spec_amp).abs().max(), 0, delta=1e-4)

    def test_dataset(self):
        """TODO:"""
        ...

    def test_init_multispeaker(self):
        num_speakers = 10
        args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
        model = Vits(args)
        assertHasAttr(self, model, "emb_g")

        args = VitsArgs(num_speakers=0, use_speaker_embedding=True)
        model = Vits(args)
        assertHasNotAttr(self, model, "emb_g")

        args = VitsArgs(num_speakers=10, use_speaker_embedding=False)
        model = Vits(args)
        assertHasNotAttr(self, model, "emb_g")

        args = VitsArgs(d_vector_dim=101, use_d_vector_file=True)
        model = Vits(args)
        self.assertEqual(model.embedded_speaker_dim, 101)

    def test_init_multilingual(self):
        args = VitsArgs(language_ids_file=None, use_language_embedding=False)
        model = Vits(args)
        self.assertEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, 0)
        assertHasNotAttr(self, model, "emb_l")

        args = VitsArgs(language_ids_file=LANG_FILE)
        model = Vits(args)
        self.assertNotEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, 0)
        assertHasNotAttr(self, model, "emb_l")

        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True)
        model = Vits(args)
        self.assertNotEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
        assertHasAttr(self, model, "emb_l")

        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, embedded_language_dim=102)
        model = Vits(args)
        self.assertNotEqual(model.language_manager, None)
        self.assertEqual(model.embedded_language_dim, args.embedded_language_dim)
        assertHasAttr(self, model, "emb_l")

    def test_get_aux_input(self):
        aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None}
        args = VitsArgs()
        model = Vits(args)
        aux_out = model.get_aux_input(aux_input)

        speaker_id = torch.randint(10, (1,))
        language_id = torch.randint(10, (1,))
        d_vector = torch.rand(1, 128)
        aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id}
        aux_out = model.get_aux_input(aux_input)
        self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape)
        self.assertEqual(aux_out["language_ids"].shape, language_id.shape)
        self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape)

    def test_voice_conversion(self):
        num_speakers = 10
        spec_len = 101
        spec_effective_len = 50

        args = VitsArgs(num_speakers=num_speakers, use_speaker_embedding=True)
        model = Vits(args)

        ref_inp = torch.randn(1, 513, spec_len)
        ref_inp_len = torch.randint(1, spec_effective_len, (1,))
        ref_spk_id = torch.randint(1, num_speakers, (1,)).item()
        tgt_spk_id = torch.randint(1, num_speakers, (1,)).item()
        o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id)

        self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))
        self.assertEqual(y_mask.shape, (1, 1, spec_len))
        self.assertEqual(y_mask.sum(), ref_inp_len[0])
        self.assertEqual(z.shape, (1, args.hidden_channels, spec_len))
        self.assertEqual(z_p.shape, (1, args.hidden_channels, spec_len))
        self.assertEqual(z_hat.shape, (1, args.hidden_channels, spec_len))

    def _create_inputs(self, config, batch_size=2):
        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
        input_lengths[-1] = 128
        spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device)
        mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device)
        spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
        spec_lengths[-1] = spec.size(2)
        waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device)
        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform

    def _check_forward_outputs(self, config, output_dict, encoder_config=None, batch_size=2):
        self.assertEqual(
            output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
        )
        self.assertEqual(output_dict["alignments"].shape, (batch_size, 128, 30))
        self.assertEqual(output_dict["alignments"].max(), 1)
        self.assertEqual(output_dict["alignments"].min(), 0)
        self.assertEqual(output_dict["z"].shape, (batch_size, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["z_p"].shape, (batch_size, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["m_p"].shape, (batch_size, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["logs_p"].shape, (batch_size, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["m_q"].shape, (batch_size, config.model_args.hidden_channels, 30))
        self.assertEqual(output_dict["logs_q"].shape, (batch_size, config.model_args.hidden_channels, 30))
        self.assertEqual(
            output_dict["waveform_seg"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"]
        )
        if encoder_config:
            self.assertEqual(output_dict["gt_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
            self.assertEqual(output_dict["syn_spk_emb"].shape, (batch_size, encoder_config.model_params["proj_dim"]))
        else:
            self.assertEqual(output_dict["gt_spk_emb"], None)
            self.assertEqual(output_dict["syn_spk_emb"], None)

    def test_forward(self):
        num_speakers = 0
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        config.model_args.spec_segment_size = 10
        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
        model = Vits(config).to(device)
        output_dict = model.forward(input_dummy, input_lengths, spec, spec_lengths, waveform)
        self._check_forward_outputs(config, output_dict)

    def test_multispeaker_forward(self):
        num_speakers = 10

        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        config.model_args.spec_segment_size = 10

        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config)
        speaker_ids = torch.randint(0, num_speakers, (8,)).long().to(device)

        model = Vits(config).to(device)
        output_dict = model.forward(
            input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"speaker_ids": speaker_ids}
        )
        self._check_forward_outputs(config, output_dict)

    def test_d_vector_forward(self):
        batch_size = 2
        args = VitsArgs(
            spec_segment_size=10,
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
        )
        config = VitsConfig(model_args=args)
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.train()
        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
        d_vectors = torch.randn(batch_size, 256).to(device)
        output_dict = model.forward(
            input_dummy, input_lengths, spec, spec_lengths, waveform, aux_input={"d_vectors": d_vectors}
        )
        self._check_forward_outputs(config, output_dict)

    def test_multilingual_forward(self):
        num_speakers = 10
        num_langs = 3
        batch_size = 2

        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)

        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)

        model = Vits(config).to(device)
        output_dict = model.forward(
            input_dummy,
            input_lengths,
            spec,
            spec_lengths,
            waveform,
            aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
        )
        self._check_forward_outputs(config, output_dict)

    def test_secl_forward(self):
        num_speakers = 10
        num_langs = 3
        batch_size = 2

        speaker_encoder_config = load_config(SPEAKER_ENCODER_CONFIG)
        speaker_encoder_config.model_params["use_torch_spec"] = True
        speaker_encoder = setup_encoder_model(speaker_encoder_config).to(device)
        speaker_manager = SpeakerManager()
        speaker_manager.encoder = speaker_encoder

        args = VitsArgs(
            language_ids_file=LANG_FILE,
            use_language_embedding=True,
            spec_segment_size=10,
            use_speaker_encoder_as_loss=True,
        )
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
        config.audio.sample_rate = 16000

        input_dummy, input_lengths, _, spec, spec_lengths, waveform = self._create_inputs(config, batch_size=batch_size)
        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)

        model = Vits(config, speaker_manager=speaker_manager).to(device)
        output_dict = model.forward(
            input_dummy,
            input_lengths,
            spec,
            spec_lengths,
            waveform,
            aux_input={"speaker_ids": speaker_ids, "language_ids": lang_ids},
        )
        self._check_forward_outputs(config, output_dict, speaker_encoder_config)

    def _check_inference_outputs(self, config, outputs, input_dummy, batch_size=1):
        feat_len = outputs["z"].shape[2]
        self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1))  # we don't know the channel dimension
        self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len))
        self.assertEqual(outputs["z"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
        self.assertEqual(outputs["z_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
        self.assertEqual(outputs["m_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))
        self.assertEqual(outputs["logs_p"].shape, (batch_size, config.model_args.hidden_channels, feat_len))

    def test_inference(self):
        num_speakers = 0
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        model = Vits(config).to(device)

        batch_size = 1
        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
        outputs = model.inference(input_dummy)
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)

        batch_size = 2
        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
        outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths})
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)

    def test_multispeaker_inference(self):
        num_speakers = 10
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True)
        model = Vits(config).to(device)

        batch_size = 1
        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
        outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids})
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)

        batch_size = 2
        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
        outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids})
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)

    def test_multilingual_inference(self):
        num_speakers = 10
        num_langs = 3
        args = VitsArgs(language_ids_file=LANG_FILE, use_language_embedding=True, spec_segment_size=10)
        config = VitsConfig(num_speakers=num_speakers, use_speaker_embedding=True, model_args=args)
        model = Vits(config).to(device)

        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        speaker_ids = torch.randint(0, num_speakers, (1,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (1,)).long().to(device)
        _ = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})

        batch_size = 1
        input_dummy, *_ = self._create_inputs(config, batch_size=batch_size)
        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
        outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids, "language_ids": lang_ids})
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)

        batch_size = 2
        input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size)
        speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device)
        lang_ids = torch.randint(0, num_langs, (batch_size,)).long().to(device)
        outputs = model.inference(
            input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids, "language_ids": lang_ids}
        )
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=batch_size)

    def test_d_vector_inference(self):
        args = VitsArgs(
            spec_segment_size=10,
            num_chars=32,
            use_d_vector_file=True,
            d_vector_dim=256,
            d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
        )
        config = VitsConfig(model_args=args)
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.eval()
        # batch size = 1
        input_dummy = torch.randint(0, 24, (1, 128)).long().to(device)
        d_vectors = torch.randn(1, 256).to(device)
        outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors})
        self._check_inference_outputs(config, outputs, input_dummy)
        # batch size = 2
        input_dummy, input_lengths, *_ = self._create_inputs(config)
        d_vectors = torch.randn(2, 256).to(device)
        outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors})
        self._check_inference_outputs(config, outputs, input_dummy, batch_size=2)

    @staticmethod
    def _check_parameter_changes(model, model_ref):
        count = 0
        for item1, item2 in zip(model.named_parameters(), model_ref.named_parameters()):
            name = item1[0]
            param = item1[1]
            param_ref = item2[1]
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                name, param.shape, param, param_ref
            )
            count = count + 1

    def _create_batch(self, config, batch_size):
        input_dummy, input_lengths, mel, spec, mel_lengths, _ = self._create_inputs(config, batch_size)
        batch = {}
        batch["tokens"] = input_dummy
        batch["token_lens"] = input_lengths
        batch["spec_lens"] = mel_lengths
        batch["mel_lens"] = mel_lengths
        batch["spec"] = spec
        batch["mel"] = mel
        batch["waveform"] = torch.rand(batch_size, 1, config.audio["sample_rate"] * 10).to(device)
        batch["d_vectors"] = None
        batch["speaker_ids"] = None
        batch["language_ids"] = None
        return batch

    def test_train_step(self):
        # setup the model
        with torch.autograd.set_detect_anomaly(True):
            config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
            model = Vits(config).to(device)
            model.train()
            # model to train
            optimizers = model.get_optimizer()
            criterions = model.get_criterion()
            criterions = [criterions[0].to(device), criterions[1].to(device)]
            # reference model to compare model weights
            model_ref = Vits(config).to(device)
            # # pass the state to ref model
            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
            count = 0
            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
                assert (param - param_ref).sum() == 0, param
                count = count + 1
            for _ in range(5):
                batch = self._create_batch(config, 2)
                for idx in [0, 1]:
                    outputs, loss_dict = model.train_step(batch, criterions, idx)
                    self.assertFalse(not outputs)
                    self.assertFalse(not loss_dict)
                    loss_dict["loss"].backward()
                    optimizers[idx].step()
                    optimizers[idx].zero_grad()

        # check parameter changes
        self._check_parameter_changes(model, model_ref)

    def test_train_step_upsampling(self):
        """Upsampling by the decoder upsampling layers"""
        # setup the model
        with torch.autograd.set_detect_anomaly(True):
            audio_config = VitsAudioConfig(sample_rate=22050)
            model_args = VitsArgs(
                num_chars=32,
                spec_segment_size=10,
                encoder_sample_rate=11025,
                interpolate_z=False,
                upsample_rates_decoder=[8, 8, 4, 2],
            )
            config = VitsConfig(model_args=model_args, audio=audio_config)
            model = Vits(config).to(device)
            model.train()
            # model to train
            optimizers = model.get_optimizer()
            criterions = model.get_criterion()
            criterions = [criterions[0].to(device), criterions[1].to(device)]
            # reference model to compare model weights
            model_ref = Vits(config).to(device)
            # # pass the state to ref model
            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
            count = 0
            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
                assert (param - param_ref).sum() == 0, param
                count = count + 1
            for _ in range(5):
                batch = self._create_batch(config, 2)
                for idx in [0, 1]:
                    outputs, loss_dict = model.train_step(batch, criterions, idx)
                    self.assertFalse(not outputs)
                    self.assertFalse(not loss_dict)
                    loss_dict["loss"].backward()
                    optimizers[idx].step()
                    optimizers[idx].zero_grad()

        # check parameter changes
        self._check_parameter_changes(model, model_ref)

    def test_train_step_upsampling_interpolation(self):
        """Upsampling by interpolation"""
        # setup the model
        with torch.autograd.set_detect_anomaly(True):
            audio_config = VitsAudioConfig(sample_rate=22050)
            model_args = VitsArgs(
                num_chars=32,
                spec_segment_size=10,
                encoder_sample_rate=11025,
                interpolate_z=True,
                upsample_rates_decoder=[8, 8, 2, 2],
            )
            config = VitsConfig(model_args=model_args, audio=audio_config)
            model = Vits(config).to(device)
            model.train()
            # model to train
            optimizers = model.get_optimizer()
            criterions = model.get_criterion()
            criterions = [criterions[0].to(device), criterions[1].to(device)]
            # reference model to compare model weights
            model_ref = Vits(config).to(device)
            # # pass the state to ref model
            model_ref.load_state_dict(copy.deepcopy(model.state_dict()))
            count = 0
            for param, param_ref in zip(model.parameters(), model_ref.parameters()):
                assert (param - param_ref).sum() == 0, param
                count = count + 1
            for _ in range(5):
                batch = self._create_batch(config, 2)
                for idx in [0, 1]:
                    outputs, loss_dict = model.train_step(batch, criterions, idx)
                    self.assertFalse(not outputs)
                    self.assertFalse(not loss_dict)
                    loss_dict["loss"].backward()
                    optimizers[idx].step()
                    optimizers[idx].zero_grad()

        # check parameter changes
        self._check_parameter_changes(model, model_ref)

    def test_train_eval_log(self):
        batch_size = 2
        config = VitsConfig(model_args=VitsArgs(num_chars=32, spec_segment_size=10))
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.run_data_dep_init = False
        model.train()
        batch = self._create_batch(config, batch_size)
        logger = TensorboardLogger(
            log_dir=os.path.join(get_tests_output_path(), "dummy_vits_logs"), model_name="vits_test_train_log"
        )
        criterion = model.get_criterion()
        criterion = [criterion[0].to(device), criterion[1].to(device)]
        outputs = [None] * 2
        outputs[0], _ = model.train_step(batch, criterion, 0)
        outputs[1], _ = model.train_step(batch, criterion, 1)
        model.train_log(batch, outputs, logger, None, 1)

        model.eval_log(batch, outputs, logger, None, 1)
        logger.finish()

    def test_test_run(self):
        config = VitsConfig(model_args=VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)
        model.run_data_dep_init = False
        model.eval()
        test_figures, test_audios = model.test_run(None)
        self.assertTrue(test_figures is not None)
        self.assertTrue(test_audios is not None)

    def test_load_checkpoint(self):
        chkp_path = os.path.join(get_tests_output_path(), "dummy_glow_tts_checkpoint.pth")
        config = VitsConfig(VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)
        chkp = {}
        chkp["model"] = model.state_dict()
        torch.save(chkp, chkp_path)
        model.load_checkpoint(config, chkp_path)
        self.assertTrue(model.training)
        model.load_checkpoint(config, chkp_path, eval=True)
        self.assertFalse(model.training)

    def test_get_criterion(self):
        config = VitsConfig(VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)
        criterion = model.get_criterion()
        self.assertTrue(criterion is not None)

    def test_init_from_config(self):
        config = VitsConfig(model_args=VitsArgs(num_chars=32))
        model = Vits.init_from_config(config, verbose=False).to(device)

        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2))
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertTrue(not hasattr(model, "emb_g"))

        config = VitsConfig(model_args=VitsArgs(num_chars=32, num_speakers=2, use_speaker_embedding=True))
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertEqual(model.num_speakers, 2)
        self.assertTrue(hasattr(model, "emb_g"))

        config = VitsConfig(
            model_args=VitsArgs(
                num_chars=32,
                num_speakers=2,
                use_speaker_embedding=True,
                speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"),
            )
        )
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertEqual(model.num_speakers, 10)
        self.assertTrue(hasattr(model, "emb_g"))

        config = VitsConfig(
            model_args=VitsArgs(
                num_chars=32,
                use_d_vector_file=True,
                d_vector_dim=256,
                d_vector_file=[os.path.join(get_tests_data_path(), "dummy_speakers.json")],
            )
        )
        model = Vits.init_from_config(config, verbose=False).to(device)
        self.assertTrue(model.num_speakers == 1)
        self.assertTrue(not hasattr(model, "emb_g"))
        self.assertTrue(model.embedded_speaker_dim == config.d_vector_dim)