
Testing ArcFace Neural Network Architecture Components in GFPGAN

This test suite validates the ArcFace architecture components in GFPGAN, focusing on the ResNetArcFace model and its building blocks. The tests check the output shapes of BasicBlock, Bottleneck, and the complete ResNetArcFace network with and without SE (squeeze-and-excitation) blocks.

Test Coverage Overview

The test suite covers the following ArcFace architecture components:
  • ResNetArcFace (IRBlock, layers=(2, 2, 2, 2)) with and without SE blocks, expecting a (1, 512) embedding for a (1, 1, 128, 128) input
  • BasicBlock with stride 1 and no downsample module, and with stride 2 plus a downsample module
  • Bottleneck with the same two configurations; its output has four times as many channels as planes
  • Execution of every component on a CUDA device
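The expected block shapes follow a simple rule: the output channel count is planes times the block's expansion factor, and a stride of 2 halves the spatial size of the 12 x 12 inputs. The expansion factors below (1 for BasicBlock, 4 for Bottleneck) are inferred from the asserted shapes rather than read from gfpgan's source, and the names EXPANSION, expected_block_shape, and CASES are hypothetical. This is a minimal pure-Python sketch that reproduces the four block assertions.

```python
# Hypothetical table: expansion factors inferred from the asserted shapes
# in tests/test_arcface_arch.py (BasicBlock: planes 3 -> 3 channels,
# Bottleneck: planes 1 -> 4 channels).
EXPANSION = {"BasicBlock": 1, "Bottleneck": 4}


def expected_block_shape(block, in_shape, planes, stride):
    """Return the (N, C, H, W) shape a block is expected to produce.

    Assumes the spatial size is divided by the stride (exact for the
    even-sized 12x12 inputs used in the tests) and that the channel
    count is planes * expansion.
    """
    n, _, h, w = in_shape
    return (n, planes * EXPANSION[block], h // stride, w // stride)


# The four block configurations exercised by the test file:
# (block, input shape, planes, stride, asserted output shape)
CASES = [
    ("BasicBlock", (1, 1, 12, 12), 3, 1, (1, 3, 12, 12)),
    ("BasicBlock", (1, 1, 12, 12), 3, 2, (1, 3, 6, 6)),
    ("Bottleneck", (1, 1, 12, 12), 1, 1, (1, 4, 12, 12)),
    ("Bottleneck", (1, 1, 12, 12), 1, 2, (1, 4, 6, 6)),
]

for block, in_shape, planes, stride, expected in CASES:
    assert expected_block_shape(block, in_shape, planes, stride) == expected
```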

Implementation Analysis

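Each test builds a module, runs a forward pass on a random input tensor, and asserts the shape of the output; no output values are compared against reference results. The checks therefore confirm that the layers are wired with consistent channel and spatial dimensions.

Every module and input is moved to the GPU explicitly with .cuda(). Only test_resnetarcface checks torch.cuda.is_available() first, and it does nothing when no GPU is present; test_basicblock and test_bottleneck call .cuda() unconditionally, so they fail on CPU-only machines. The block tests exercise both the regular path (stride 1, no downsample module) and the downsampled path (stride 2, with torch.nn.UpsamplingNearest2d(scale_factor=0.5) as the downsample module).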
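In the downsampled cases, the main path and the residual path must end up with the same spatial size. A minimal sketch of that arithmetic, assuming the blocks' strided layer is a 3x3 convolution with padding 1 (the usual ResNet convention, not confirmed here from gfpgan's source): a stride-2 convolution maps 12 to 6, and nearest-neighbour resampling with scale_factor=0.5 also maps 12 to 6, because PyTorch computes the resampled size as floor(size * scale_factor). The function names are hypothetical.

```python
import math


def conv_out_size(size, kernel=3, stride=1, padding=1, dilation=1):
    """Spatial output size of a 2D convolution (standard formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def nearest_resample_size(size, scale_factor):
    """Output size of nearest-neighbour resampling driven by a scale factor."""
    return math.floor(size * scale_factor)


# Main path: assumed 3x3 convolution with padding 1 and stride 2.
main = conv_out_size(12, kernel=3, stride=2, padding=1)
# Residual path: UpsamplingNearest2d(scale_factor=0.5) used as the downsample module.
residual = nearest_resample_size(12, 0.5)
print(main, residual)  # 6 6
```

Under the same padding assumption, an odd input size such as 13 gives 7 on the main path but 6 on the residual path, so the even 12 x 12 input keeps the two paths aligned.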

Technical Details

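  • Testing Framework: pytest-style test_* functions with plain assert statements (no torch.testing utilities)
  • Hardware Requirements: CUDA-capable GPU (required by test_basicblock and test_bottleneck; test_resnetarcface does nothing without one)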
  • Input Tensors: Randomly generated using torch.rand
  • Validation Method: Shape assertion checks
  • Key Components: ResNetArcFace, BasicBlock, Bottleneck
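The shape checks compare a torch.Size with a plain tuple, which works because torch.Size is a subclass of tuple. A hypothetical helper such as assert_shape below (not part of GFPGAN) keeps that comparison but reports both shapes when it fails; it is sketched with plain tuples so it runs without PyTorch.

```python
def assert_shape(actual, expected, name="output"):
    """Raise AssertionError naming both shapes if they differ.

    Accepts any sequence of ints, including torch.Size (a tuple subclass).
    """
    actual = tuple(actual)
    expected = tuple(expected)
    if actual != expected:
        raise AssertionError(f"{name} shape {actual} != expected {expected}")


assert_shape((1, 512), (1, 512))  # matching shapes: passes silently

try:
    assert_shape((1, 3, 12, 12), (1, 3, 6, 6), name="BasicBlock output")
except AssertionError as err:
    print(err)  # BasicBlock output shape (1, 3, 12, 12) != expected (1, 3, 6, 6)
```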

Best Practices Demonstrated

The test suite exhibits several testing best practices:
  • Isolated component testing for each architectural block
  • Explicit shape validation for tensor operations
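  • A CUDA availability check for the full-model test (the block tests lack one)
  • One test function per component, each with a short docstring
  • Coverage of the key configurations: SE blocks on and off for the full model, and stride 1 versus stride 2 with downsampling for each block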

tencentarc/gfpgan

tests/test_arcface_arch.py

            
import torch

from gfpgan.archs.arcface_arch import BasicBlock, Bottleneck, ResNetArcFace


def test_resnetarcface():
    """Test arch: ResNetArcFace."""

    # model init and forward (gpu)
    if torch.cuda.is_available():
        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).cuda().eval()
        img = torch.rand((1, 1, 128, 128), dtype=torch.float32).cuda()
        output = net(img)
        assert output.shape == (1, 512)

        # -------------------- without SE block ----------------------- #
        net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False).cuda().eval()
        output = net(img)
        assert output.shape == (1, 512)


def test_basicblock():
    """Test the BasicBlock in arcface_arch"""
    block = BasicBlock(1, 3, stride=1, downsample=None).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 3, 12, 12)

    # ------------------ use the downsample module ------------------ #
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
    block = BasicBlock(1, 3, stride=2, downsample=downsample).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 3, 6, 6)


def test_bottleneck():
    """Test the Bottleneck in arcface_arch"""
    block = Bottleneck(1, 1, stride=1, downsample=None).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 4, 12, 12)

    # ------------------ use the downsample module ------------------ #
    downsample = torch.nn.UpsamplingNearest2d(scale_factor=0.5).cuda()
    block = Bottleneck(1, 1, stride=2, downsample=downsample).cuda()
    img = torch.rand((1, 1, 12, 12), dtype=torch.float32).cuda()
    output = block(img)
    assert output.shape == (1, 4, 6, 6)
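In both block tests the input has 1 channel while the block outputs 3 (BasicBlock) or 4 (Bottleneck) channels, and neither the identity path nor UpsamplingNearest2d changes the channel count. Assuming the blocks add the residual to the main path directly, as standard ResNet blocks do, that addition only succeeds because PyTorch broadcasts the size-1 channel dimension. The pure-Python function below sketches the broadcasting rule PyTorch shares with NumPy; broadcast_shape is a hypothetical name, not a GFPGAN function.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes using NumPy/PyTorch rules.

    Shapes are aligned from the trailing dimension; each pair of sizes
    must be equal or include a 1, and missing leading dims count as 1.
    """
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1
        y = b[-i] if i <= len(b) else 1
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dimension stretches to match the other size.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


# Main path of BasicBlock(1, 3, stride=2) vs. the 1-channel downsampled residual.
print(broadcast_shape((1, 3, 6, 6), (1, 1, 6, 6)))  # (1, 3, 6, 6)
```

Because the broadcast result equals the main-path shape, an in-place addition onto the main path is valid. In a standard ResNet, the downsample module is typically a 1x1 convolution followed by batch normalization, which matches the channel count explicitly instead of relying on broadcasting.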