Testing PyTorch Parameter Extraction in Stable-Diffusion-WebUI
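This test checks torch_utils.get_param from the Stable Diffusion WebUI's modules package. It converts a torch.nn.Linear layer to float16 on the CPU, then asserts that the parameter returned by get_param reports that dtype and device. The test is parametrized over two cases: the layer passed directly, and the layer wrapped in an object that exposes it through a model attribute, roughly the way spandrel wraps models.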
Repository: automatic1111/stable-diffusion-webui
File: test/test_torch_utils.py
import types

import pytest
import torch

from modules import torch_utils


@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    # A tiny module whose parameters are moved to float16 on the CPU.
    mod = torch.nn.Linear(1, 1)
    cpu = torch.device("cpu")
    mod.to(dtype=torch.float16, device=cpu)

    if wrapped:
        # more or less how spandrel wraps a thing
        mod = types.SimpleNamespace(model=mod)

    # get_param should return a parameter carrying the module's dtype and
    # device, whether it receives the bare module or the wrapper.
    p = torch_utils.get_param(mod)
    assert p.dtype == torch.float16
    assert p.device == cpu
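The test only pins down behavior: given either a module or a wrapper holding one in a model attribute, get_param must hand back a parameter whose dtype and device match the module's. A helper that satisfies both cases might look like the sketch below. It is a hypothetical illustration, not the WebUI's actual implementation; the name get_param_sketch and the choice of raising ValueError for a parameter-less module are assumptions. It relies only on duck typing (a parameters() method), so it runs without PyTorch installed.

```python
from typing import Any


def get_param_sketch(model: Any) -> Any:
    """Return the first parameter of a module, unwrapping a `model` attribute.

    Hypothetical sketch of the behavior test_get_param checks; this is not
    the WebUI's torch_utils.get_param.
    """
    # Wrappers (spandrel-style descriptors, or the SimpleNamespace used in
    # the test) keep the real module in a `model` attribute.
    inner = getattr(model, "model", None)
    if inner is not None and hasattr(inner, "parameters"):
        model = inner

    # torch.nn.Module.parameters() returns an iterator; the first parameter
    # is enough to read the module's dtype and device.
    for param in model.parameters():
        return param

    # Assumed error behavior for a module with no parameters at all.
    raise ValueError(f"No parameters found in model {model!r}")


if __name__ == "__main__":
    import types

    class FakeModule:
        # Stand-in for torch.nn.Module: only provides parameters().
        def __init__(self, params):
            self._params = list(params)

        def parameters(self):
            return iter(self._params)

    bare = FakeModule(["weight", "bias"])
    wrapped = types.SimpleNamespace(model=bare)
    print(get_param_sketch(bare), get_param_sketch(wrapped))  # weight weight
```

Reading only the first parameter assumes every parameter in the module shares one dtype and device, which holds after a whole-module call such as mod.to(dtype=torch.float16, device=cpu) in the test.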