Validating Speaker Encoder Models and Loss Functions in Coqui-AI TTS
This test suite validates the functionality of speaker encoder models and associated loss functions in the Coqui-AI TTS system. It includes comprehensive testing of LSTM and ResNet speaker encoders along with various loss computation methods like GE2E, AngleProto, and SoftmaxAngleProto losses.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
coqui-ai/tts
tests/aux_tests/test_speaker_encoder.py
import unittest
import torch as T
from tests import get_tests_input_path
from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
# Value returned by the project's `get_tests_input_path` helper; presumably the
# directory holding test input fixtures (TODO confirm against `tests/__init__.py`).
# NOTE(review): not referenced by any test case visible in this file.
file_path = get_tests_input_path()
class LSTMSpeakerEncoderTests(unittest.TestCase):
    """Smoke tests for output shapes and L2 normalisation of ``LSTMSpeakerEncoder``."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Run ``forward``, ``inference`` and ``compute_embedding`` on random inputs.

        Checks that each call yields one ``proj_dim``-sized vector per batch item,
        that the inference output is a CPU float tensor, and that it is already
        L2-normalised along dim 1.
        """
        batch_size = 4
        input_dim = 80
        proj_dim = 256

        dummy_input = T.rand(batch_size, input_dim, 20)  # B x D x T
        model = LSTMSpeakerEncoder(input_dim=input_dim, proj_dim=proj_dim, lstm_dim=768, num_lstm_layers=3)

        # computing d vectors
        output = model.forward(dummy_input)
        assert output.shape[0] == batch_size
        assert output.shape[1] == proj_dim
        output = model.inference(dummy_input)
        assert output.shape[0] == batch_size
        assert output.shape[1] == proj_dim

        # check normalization: the test expects the inference output to be
        # L2-normalised already, so normalising it again must leave it unchanged.
        # Use the largest absolute element-wise deviation; summing the signed
        # differences would let positive and negative errors cancel out.
        output_norm = T.nn.functional.normalize(output, dim=1, p=2)
        max_diff = (output_norm - output).abs().max().item()
        assert output.type() == "torch.FloatTensor"
        assert max_diff < 1e-4, f" [!] output_norm has wrong values - {max_diff}"

        # compute d for a given batch
        # NOTE(review): this input was originally annotated "B x T x D", which
        # conflicts with the "B x D x T" annotation on the first input above
        # (same feature size of 80 on dim 1) - confirm which layout
        # compute_embedding expects.
        dummy_input = T.rand(1, input_dim, 240)
        output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
        assert output.shape[0] == 1
        assert output.shape[1] == proj_dim
        assert len(output.shape) == 2
class ResNetSpeakerEncoderTests(unittest.TestCase):
    """Smoke tests for output shapes and L2 normalisation of ``ResNetSpeakerEncoder``."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Run ``forward`` (with and without ``l2_norm=True``) and ``compute_embedding``.

        Checks that each call yields one ``proj_dim``-sized vector per batch item,
        that the ``l2_norm=True`` output is a CPU float tensor, and that it is
        L2-normalised along dim 1.
        """
        batch_size = 4
        input_dim = 80
        proj_dim = 256

        dummy_input = T.rand(batch_size, input_dim, 20)  # B x D x T
        model = ResNetSpeakerEncoder(input_dim=input_dim, proj_dim=proj_dim)

        # computing d vectors
        output = model.forward(dummy_input)
        assert output.shape[0] == batch_size
        assert output.shape[1] == proj_dim
        output = model.forward(dummy_input, l2_norm=True)
        assert output.shape[0] == batch_size
        assert output.shape[1] == proj_dim

        # check normalization: output requested with l2_norm=True must be
        # unchanged by a second L2 normalisation. Use the largest absolute
        # element-wise deviation; summing the signed differences would let
        # positive and negative errors cancel out.
        output_norm = T.nn.functional.normalize(output, dim=1, p=2)
        max_diff = (output_norm - output).abs().max().item()
        assert output.type() == "torch.FloatTensor"
        assert max_diff < 1e-4, f" [!] output_norm has wrong values - {max_diff}"

        # compute d for a given batch
        dummy_input = T.rand(1, input_dim, 240)  # B x D x T
        output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
        assert output.shape[0] == 1
        assert output.shape[1] == proj_dim
        assert len(output.shape) == 2
class GE2ELossTests(unittest.TestCase):
    """Sanity checks for ``GE2ELoss`` with the ``"softmax"`` loss method."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Check the loss on random, constant and orthogonal embeddings.

        The first two cases only require a non-negative loss; the orthogonal
        case requires the loss to be below 0.005.
        """
        # Case 1: random embeddings, laid out as num_speaker x num_utterance x dim.
        random_embeddings = T.rand(4, 5, 64)
        criterion = GE2ELoss(loss_method="softmax")
        loss_value = criterion.forward(random_embeddings)
        assert loss_value.item() >= 0.0

        # Case 2: constant embeddings (every element is 1; the original comment
        # labelled this the "all zeros" case).
        constant_embeddings = T.ones(4, 5, 64)
        criterion = GE2ELoss(loss_method="softmax")
        loss_value = criterion.forward(constant_embeddings)
        assert loss_value.item() >= 0.0

        # Case 3: speaker loss with orthogonal d-vectors. Each of the 3 speakers
        # gets 5 identical utterance embeddings equal to one row of an
        # orthogonally-initialised 3 x 64 matrix.
        speaker_vectors = T.nn.init.orthogonal_(T.empty(3, 64))
        orthogonal_embeddings = T.stack(
            [vector.repeat(5, 1) for vector in speaker_vectors]
        )  # num_speaker x num_utterance x dim
        criterion = GE2ELoss(loss_method="softmax")
        loss_value = criterion.forward(orthogonal_embeddings)
        assert loss_value.item() < 0.005
class AngleProtoLossTests(unittest.TestCase):
    """Sanity checks for ``AngleProtoLoss``."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Check the loss on random, constant and orthogonal embeddings.

        The first two cases only require a non-negative loss; the orthogonal
        case requires the loss to be below 0.005.
        """
        # Case 1: random embeddings, laid out as num_speaker x num_utterance x dim.
        random_embeddings = T.rand(4, 5, 64)
        criterion = AngleProtoLoss()
        loss_value = criterion.forward(random_embeddings)
        assert loss_value.item() >= 0.0

        # Case 2: constant embeddings (every element is 1; the original comment
        # labelled this the "all zeros" case).
        constant_embeddings = T.ones(4, 5, 64)
        criterion = AngleProtoLoss()
        loss_value = criterion.forward(constant_embeddings)
        assert loss_value.item() >= 0.0

        # Case 3: speaker loss with orthogonal d-vectors. Each of the 3 speakers
        # gets 5 identical utterance embeddings equal to one row of an
        # orthogonally-initialised 3 x 64 matrix.
        speaker_vectors = T.nn.init.orthogonal_(T.empty(3, 64))
        orthogonal_embeddings = T.stack(
            [vector.repeat(5, 1) for vector in speaker_vectors]
        )  # num_speaker x num_utterance x dim
        criterion = AngleProtoLoss()
        loss_value = criterion.forward(orthogonal_embeddings)
        assert loss_value.item() < 0.005
class SoftmaxAngleProtoLossTests(unittest.TestCase):
    """Sanity checks for ``SoftmaxAngleProtoLoss``."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Check that the loss is non-negative for random and constant embeddings."""
        dim = 64
        n_speakers = 5
        n_batch = 4
        # NOTE(review): the original inline comments describe the embedding axes
        # as "num_speaker x num_utterance x dim", while the sizes used are
        # (batch_size, num_speakers, embedding_dim) - confirm the intended
        # meaning of the first two axes against the loss implementation.
        label_shape = (n_batch, n_speakers)

        # One integer label in [0, n_speakers) per embedding vector.
        labels = T.randint(low=0, high=n_speakers, size=label_shape)

        # Case 1: random embeddings.
        random_embeddings = T.rand(*label_shape, dim)
        criterion = SoftmaxAngleProtoLoss(embedding_dim=dim, n_speakers=n_speakers)
        loss_value = criterion.forward(random_embeddings, labels)
        assert loss_value.item() >= 0.0

        # Case 2: constant embeddings (every element is 1; the original comment
        # labelled this the "all zeros" case).
        constant_embeddings = T.ones(*label_shape, dim)
        criterion = SoftmaxAngleProtoLoss(embedding_dim=dim, n_speakers=n_speakers)
        loss_value = criterion.forward(constant_embeddings, labels)
        assert loss_value.item() >= 0.0