Testing PyTorch Parameter Extraction in Stable-Diffusion-WebUI

This test suite validates the PyTorch utility functions in the Stable Diffusion WebUI, focusing on parameter access and device management. It confirms that model parameters can be retrieved, with their dtype and device intact, whether the model is used directly or wrapped.

Test Coverage Overview

The test coverage focuses on the get_param utility function, verifying parameter extraction from PyTorch models; a sketch of such a helper follows the list below.

  • Tests both wrapped and unwrapped model scenarios
  • Validates parameter dtype preservation
  • Confirms device assignment consistency
  • Covers core model parameter access patterns
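For reference, the sketch below shows the kind of helper these tests exercise. It is an assumption based on the test's expectations rather than the actual modules/torch_utils.py source: unwrap a .model attribute when one is present (the Spandrel-style case) and return the first parameter of the underlying module.

import torch


def get_param_sketch(model) -> torch.nn.Parameter:
    # Hypothetical stand-in for modules.torch_utils.get_param: if the object
    # exposes the torch module via a `.model` attribute, unwrap it first,
    # then return the first parameter of the module.
    if hasattr(model, "model") and hasattr(model.model, "parameters"):
        model = model.model
    return next(model.parameters())

Because the returned tensor is a real parameter of the underlying nn.Linear, its dtype and device reflect whatever the module was moved to, which is exactly what the assertions in the test check.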

Implementation Analysis

The testing approach employs pytest’s parametrize feature to systematically test parameter extraction across different model configurations.

The tests use types.SimpleNamespace to simulate a wrapped model, mirroring how the WebUI's Spandrel integration exposes the underlying torch module, and they cover both direct and wrapped access.
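As a quick illustration of that stand-in (the .model attribute name follows the test's own comment about how Spandrel wraps a module; the snippet is illustrative, not Spandrel code):

import types

import torch

net = torch.nn.Linear(1, 1)

# A bare namespace exposing the module as `.model` is enough to exercise the
# unwrapping path without importing spandrel in the test environment.
wrapped = types.SimpleNamespace(model=net)
assert wrapped.model is net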

Technical Details

  • pytest framework for test organization
  • PyTorch for neural network operations
  • types.SimpleNamespace for model wrapping
  • torch.nn.Linear as test model
  • CPU device testing configuration
  • float16 dtype verification

Best Practices Demonstrated

The test suite exemplifies robust testing practices through parameterized test cases and explicit assertion checking.

  • Parametrized testing for multiple scenarios (see the sketch after this list)
  • Explicit device and dtype verification
  • Clear separation of test cases
  • Simulation of real-world usage patterns
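As a hypothetical extension of the same pattern (a sketch, not part of the actual suite), stacking a second parametrize decorator over dtype multiplies the covered scenarios without duplicating the test body:

import types

import pytest
import torch

from modules import torch_utils


@pytest.mark.parametrize("wrapped", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_get_param_dtypes(wrapped, dtype):
    # Stacked parametrize decorators run the cross-product:
    # two wrapping modes x two dtypes = four test cases.
    mod = torch.nn.Linear(1, 1)
    cpu = torch.device("cpu")
    mod.to(dtype=dtype, device=cpu)
    if wrapped:
        mod = types.SimpleNamespace(model=mod)
    p = torch_utils.get_param(mod)
    assert p.dtype == dtype
    assert p.device == cpu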

automatic1111/stable-diffusion-webui

test/test_torch_utils.py

            
import types

import pytest
import torch

from modules import torch_utils


@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    mod = torch.nn.Linear(1, 1)
    cpu = torch.device("cpu")
    mod.to(dtype=torch.float16, device=cpu)
    if wrapped:
        # more or less how spandrel wraps a thing
        mod = types.SimpleNamespace(model=mod)
    p = torch_utils.get_param(mod)
    assert p.dtype == torch.float16
    assert p.device == cpu