
Validating Custom Loss Functions in Faceswap Deep Learning Framework

This test suite validates the custom loss functions implemented in the Faceswap deep learning project. It confirms that the loss calculations used to train deepfake models, including specialized losses such as DSSIM, Focal Frequency Loss, and Gradient Loss, produce correctly shaped, NaN-free float32 outputs.

Test Coverage Overview

The test suite provides comprehensive coverage of multiple loss function implementations:
  • Basic shape validation for GeneralizedLoss, GradientLoss, GMSDLoss, and LInfNorm
  • Loss wrapper functionality testing with 12 different loss functions
  • Edge cases handling with different input tensor shapes
  • Integration with TensorFlow/Keras backend operations

Implementation Analysis

The testing approach uses pytest's parametrize feature for efficient test-case management. Tests verify both standalone loss functions and a wrapper class that combines multiple losses with different weights. The implementation uses NumPy random arrays to simulate realistic input tensors and validates output shapes, datatypes, and the absence of NaN values.
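The validation pattern can be illustrated with a minimal NumPy sketch. Note that `check_loss_output` and the `mae` stand-in below are hypothetical illustrations of the assertion style, not Faceswap's actual loss implementations:

```python
import numpy as np


def check_loss_output(loss_func, y_true, y_pred):
    """Validate a loss output the way the suite does: float32, no NaNs."""
    output = np.asarray(loss_func(y_true, y_pred), dtype=np.float32)
    assert output.dtype == np.float32
    assert not np.any(np.isnan(output))
    return output.shape


# Hypothetical stand-in for a per-pixel loss (mean absolute error over channels)
mae = lambda y_a, y_b: np.mean(np.abs(y_a - y_b), axis=-1)

y_a = np.random.random((2, 16, 16, 3)).astype("float32")
y_b = np.random.random((2, 16, 16, 3)).astype("float32")
print(check_loss_output(mae, y_a, y_b))  # (2, 16, 16)
```

The real tests apply the same dtype and NaN assertions to tensors produced by the Keras backend rather than plain NumPy arrays.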

Technical Details

Testing tools and configuration:
  • pytest framework for test organization
  • TensorFlow Keras backend for tensor operations
  • NumPy for array manipulation and validation
  • Custom loss wrapper implementation for combined loss calculations
  • Parameterized testing for multiple loss functions
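The combined-loss behaviour exercised by the wrapper tests can be sketched roughly as follows. This is a hypothetical, simplified combiner for illustration only, not Faceswap's `losses.LossWrapper` (which additionally supports per-loss mask channels, as seen in the `add_loss` calls in the listing below):

```python
import numpy as np


class WeightedLossCombiner:
    """Hypothetical sketch of a LossWrapper-style combiner.

    Each registered loss contributes weight * mean(loss(y_true, y_pred)).
    """

    def __init__(self):
        self._losses = []  # list of (loss_function, weight) pairs

    def add_loss(self, function, weight=1.0):
        self._losses.append((function, weight))

    def __call__(self, y_true, y_pred):
        total = 0.0
        for function, weight in self._losses:
            total += weight * float(np.mean(function(y_true, y_pred)))
        return np.float32(total)


mae = lambda a, b: np.abs(a - b)   # stand-in losses for illustration
mse = lambda a, b: (a - b) ** 2

combined = WeightedLossCombiner()
combined.add_loss(mae, 1.0)
combined.add_loss(mse, 2.0)

y_a = np.ones((2, 64, 64, 3), dtype="float32")
y_b = np.zeros((2, 64, 64, 3), dtype="float32")
print(combined(y_a, y_b))  # mae mean 1.0 * 1.0 + mse mean 1.0 * 2.0 = 3.0
```

Weighting each loss before summation mirrors how the wrapper lets a cheap pixel loss and an expensive perceptual loss contribute in fixed proportions to a single training objective.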

Best Practices Demonstrated

The test suite exhibits several testing best practices:
  • Systematic parameterization of test cases
  • Clear test function naming and organization
  • Comprehensive validation of output types and values
  • Proper handling of framework-specific components
  • Modular test structure for easy maintenance

deepfakes/faceswap

tests/lib/model/losses_test.py

#!/usr/bin/env python3
""" Tests for Faceswap Losses.

Adapted from Keras tests.
"""

import pytest
import numpy as np

# Ignore linting errors from Tensorflow's thoroughly broken import system
from tensorflow.keras import backend as K, losses as k_losses  # noqa:E501  # pylint:disable=import-error


from lib.model import losses
from lib.utils import get_backend

_PARAMS = [(losses.GeneralizedLoss(), (2, 16, 16)),
           (losses.GradientLoss(), (2, 16, 16)),
           # TODO Make sure these output dimensions are correct
           (losses.GMSDLoss(), (2, 1, 1)),
           # TODO Make sure these output dimensions are correct
           (losses.LInfNorm(), (2, 1, 1))]
_IDS = ["GeneralizedLoss", "GradientLoss", "GMSDLoss", "LInfNorm"]
_IDS = [f"{loss}[{get_backend().upper()}]" for loss in _IDS]


@pytest.mark.parametrize(["loss_func", "output_shape"], _PARAMS, ids=_IDS)
def test_loss_output(loss_func, output_shape):
    """ Basic shape tests for loss functions. """
    y_a = K.variable(np.random.random((2, 16, 16, 3)))
    y_b = K.variable(np.random.random((2, 16, 16, 3)))
    objective_output = loss_func(y_a, y_b)
    output = objective_output.numpy()
    assert output.dtype == "float32" and not np.any(np.isnan(output))


_LWPARAMS = [losses.DSSIMObjective(),
             losses.FocalFrequencyLoss(),
             losses.GeneralizedLoss(),
             losses.GMSDLoss(),
             losses.GradientLoss(),
             losses.LaplacianPyramidLoss(),
             losses.LDRFLIPLoss(),
             losses.LInfNorm(),
             k_losses.logcosh,  # pylint:disable=no-member
             k_losses.mean_absolute_error,
             k_losses.mean_squared_error,
             losses.MSSIMLoss()]
_LWIDS = ["DSSIMObjective", "FocalFrequencyLoss", "GeneralizedLoss", "GMSDLoss", "GradientLoss",
          "LaplacianPyramidLoss", "LDRFLIPLoss", "LInfNorm", "logcosh", "mae", "mse", "MS-SSIM"]
_LWIDS = [f"{loss}[{get_backend().upper()}]" for loss in _LWIDS]


@pytest.mark.parametrize("loss_func", _LWPARAMS, ids=_LWIDS)
def test_loss_wrapper(loss_func):
    """ Test penalized loss wrapper works as expected """
    y_a = K.variable(np.random.random((2, 64, 64, 4)))
    y_b = K.variable(np.random.random((2, 64, 64, 3)))
    p_loss = losses.LossWrapper()
    p_loss.add_loss(loss_func, 1.0, -1)
    p_loss.add_loss(k_losses.mean_squared_error, 2.0, 3)
    output = p_loss(y_a, y_b)
    output = output.numpy()
    assert output.dtype == "float32" and not np.any(np.isnan(output))