Back to Repositories

Validating Box Space Implementation in OpenAI Gym

This test suite validates the Box space implementation in OpenAI Gym, focusing on shape inference, boundary conditions, and data type handling. The tests ensure proper initialization, validation, and functionality of Box spaces which are fundamental for defining continuous action and observation spaces in reinforcement learning environments.

Test Coverage Overview

The test suite provides comprehensive coverage of Box space functionality including:
  • Shape inference validation across different input dimensions
  • Boundary value testing for low/high limits
  • Constructor error handling and validation
  • Data type compatibility checks
  • Infinite space boundary testing
  • State serialization verification

Implementation Analysis

The testing approach utilizes pytest’s parametrized testing to systematically verify Box space behavior. The implementation focuses on edge cases and validation using numpy arrays, handling various data types and shapes while ensuring proper space constraints and sampling behavior.

Technical Details

Key technical components include:
  • pytest framework for test organization
  • numpy for array operations and data types
  • Custom assertion logic for space validation
  • Regular expressions for error message validation
  • Warning capture mechanisms

Best Practices Demonstrated

The test suite exemplifies strong testing practices through:
  • Comprehensive parameterized test cases
  • Explicit error message validation
  • Thorough edge case coverage
  • Clear test function organization
  • Proper test isolation and setup

openai/gym

tests/spaces/test_box.py

            
import re
import warnings

import numpy as np
import pytest

import gym.error
from gym.spaces import Box
from gym.spaces.box import get_inf


@pytest.mark.parametrize(
    "box,expected_shape",
    [
        (  # Test with same 1-dim low and high shape
            Box(low=np.zeros(2), high=np.ones(2), dtype=np.int32),
            (2,),
        ),
        (  # Test with same multi-dim low and high shape
            Box(low=np.zeros((2, 1)), high=np.ones((2, 1)), dtype=np.int32),
            (2, 1),
        ),
        (  # Test with scalar low high and different shape
            Box(low=0, high=1, shape=(5, 2)),
            (5, 2),
        ),
        (Box(low=0, high=1), (1,)),  # Test with int and int
        (Box(low=0.0, high=1.0), (1,)),  # Test with float and float
        (Box(low=np.zeros(1)[0], high=np.ones(1)[0]), (1,)),
        (Box(low=0.0, high=1), (1,)),  # Test with float and int
        (Box(low=0, high=np.int32(1)), (1,)),  # Test with python int and numpy int32
        (Box(low=0, high=np.ones(3)), (3,)),  # Test with array and scalar
        (Box(low=np.zeros(3), high=1.0), (3,)),  # Test with array and scalar
    ],
)
def test_shape_inference(box, expected_shape):
    """Test that the shape inference is as expected."""
    assert box.shape == expected_shape
    assert box.sample().shape == expected_shape


@pytest.mark.parametrize(
    "value,valid",
    [
        (1, True),
        (1.0, True),
        (np.int32(1), True),
        (np.float32(1.0), True),
        (np.zeros(2, dtype=np.float32), True),
        (np.zeros((2, 2), dtype=np.float32), True),
        (np.inf, True),
        (np.nan, True),  # This is a weird case that we allow
        (True, False),
        (np.bool8(True), False),
        (1 + 1j, False),
        (np.complex128(1 + 1j), False),
        ("string", False),
    ],
)
def test_low_high_values(value, valid: bool):
    """Test what `low` and `high` values are valid for `Box` space."""
    if valid:
        with warnings.catch_warnings(record=True) as caught_warnings:
            Box(low=value, high=value)
        assert len(caught_warnings) == 0, tuple(
            warning.message for warning in caught_warnings
        )
    else:
        with pytest.raises(
            ValueError,
            match=re.escape(
                "expect their types to be np.ndarray, an integer or a float"
            ),
        ):
            Box(low=value, high=value)


@pytest.mark.parametrize(
    "low,high,kwargs,error,message",
    [
        (
            0,
            1,
            {"dtype": None},
            AssertionError,
            "Box dtype must be explicitly provided, cannot be None.",
        ),
        (
            0,
            1,
            {"shape": (None,)},
            AssertionError,
            "Expect all shape elements to be an integer, actual type: (<class 'NoneType'>,)",
        ),
        (
            0,
            1,
            {
                "shape": (
                    1,
                    None,
                )
            },
            AssertionError,
            "Expect all shape elements to be an integer, actual type: (<class 'int'>, <class 'NoneType'>)",
        ),
        (
            0,
            1,
            {
                "shape": (
                    np.int64(1),
                    None,
                )
            },
            AssertionError,
            "Expect all shape elements to be an integer, actual type: (<class 'numpy.int64'>, <class 'NoneType'>)",
        ),
        (
            None,
            None,
            {},
            ValueError,
            "Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: <class 'NoneType'>, high: <class 'NoneType'>",
        ),
        (
            0,
            None,
            {},
            ValueError,
            "Box shape is inferred from low and high, expect their types to be np.ndarray, an integer or a float, actual type low: <class 'int'>, high: <class 'NoneType'>",
        ),
        (
            np.zeros(3),
            np.ones(2),
            {},
            AssertionError,
            "high.shape doesn't match provided shape, high.shape: (2,), shape: (3,)",
        ),
    ],
)
def test_init_errors(low, high, kwargs, error, message):
    """Test all constructor errors."""
    with pytest.raises(error, match=f"^{re.escape(message)}$"):
        Box(low=low, high=high, **kwargs)


def test_dtype_check():
    """Tests the Box contains function with different dtypes."""
    # Related Issues:
    # https://github.com/openai/gym/issues/2357
    # https://github.com/openai/gym/issues/2298

    space = Box(0, 1, (), dtype=np.float32)

    # casting will match the correct type
    assert np.array(0.5, dtype=np.float32) in space

    # float16 is in float32 space
    assert np.array(0.5, dtype=np.float16) in space

    # float64 is not in float32 space
    assert np.array(0.5, dtype=np.float64) not in space


@pytest.mark.parametrize(
    "space",
    [
        Box(low=0, high=np.inf, shape=(2,), dtype=np.int32),
        Box(low=0, high=np.inf, shape=(2,), dtype=np.float32),
        Box(low=0, high=np.inf, shape=(2,), dtype=np.int64),
        Box(low=0, high=np.inf, shape=(2,), dtype=np.float64),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.int32),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.float32),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.int64),
        Box(low=-np.inf, high=0, shape=(2,), dtype=np.float64),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int32),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.int64),
        Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int32),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float32),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.int64),
        Box(low=0, high=np.inf, shape=(2, 3), dtype=np.float64),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int32),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float32),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.int64),
        Box(low=-np.inf, high=0, shape=(2, 3), dtype=np.float64),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int32),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float32),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.int64),
        Box(low=-np.inf, high=np.inf, shape=(2, 3), dtype=np.float64),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int32),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float32),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.int64),
        Box(low=np.array([-np.inf, 0]), high=np.array([0.0, np.inf]), dtype=np.float64),
    ],
)
def test_infinite_space(space):
    """
    To test spaces that are passed in have only 0 or infinite bounds because `space.high` and `space.low`
     are both modified within the init, we check for infinite when we know it's not 0
    """

    assert np.all(
        space.low < space.high
    ), f"Box low bound ({space.low}) is not lower than the high bound ({space.high})"

    space.seed(0)
    sample = space.sample()

    # check if space contains sample
    assert (
        sample in space
    ), f"Sample ({sample}) not inside space according to `space.contains()`"

    # manually check that the sign of the sample is within the bounds
    assert np.all(
        np.sign(sample) <= np.sign(space.high)
    ), f"Sign of sample ({sample}) is less than space upper bound ({space.high})"
    assert np.all(
        np.sign(space.low) <= np.sign(sample)
    ), f"Sign of sample ({sample}) is more than space lower bound ({space.low})"

    # check that int bounds are bounded for everything
    # but floats are unbounded for infinite
    if np.any(space.high != 0):
        assert (
            space.is_bounded("above") is False
        ), "inf upper bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("above") is True
        ), "non-inf upper bound supposed to be bounded"

    if np.any(space.low != 0):
        assert (
            space.is_bounded("below") is False
        ), "inf lower bound supposed to be unbounded"
    else:
        assert (
            space.is_bounded("below") is True
        ), "non-inf lower bound supposed to be bounded"

    if np.any(space.low != 0) or np.any(space.high != 0):
        assert space.is_bounded("both") is False
    else:
        assert space.is_bounded("both") is True

    # check for dtype
    assert (
        space.high.dtype == space.dtype
    ), f"High's dtype {space.high.dtype} doesn't match `space.dtype`'"
    assert (
        space.low.dtype == space.dtype
    ), f"Low's dtype {space.high.dtype} doesn't match `space.dtype`'"

    with pytest.raises(
        ValueError, match="manner is not in {'below', 'above', 'both'}, actual value:"
    ):
        space.is_bounded("test")


def test_legacy_state_pickling():
    legacy_state = {
        "dtype": np.dtype("float32"),
        "_shape": (5,),
        "low": np.array([0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32),
        "high": np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32),
        "bounded_below": np.array([True, True, True, True, True]),
        "bounded_above": np.array([True, True, True, True, True]),
        "_np_random": None,
    }

    b = Box(-1, 1, ())
    assert "low_repr" in b.__dict__ and "high_repr" in b.__dict__
    del b.__dict__["low_repr"]
    del b.__dict__["high_repr"]
    assert "low_repr" not in b.__dict__ and "high_repr" not in b.__dict__

    b.__setstate__(legacy_state)
    assert b.low_repr == "0.0"
    assert b.high_repr == "1.0"


def test_get_inf():
    """Tests that get inf function works as expected, primarily for coverage."""
    assert get_inf(np.float32, "+") == np.inf
    assert get_inf(np.float16, "-") == -np.inf
    with pytest.raises(
        TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
    ):
        get_inf(np.float32, "*")

    assert get_inf(np.int16, "+") == 32765
    assert get_inf(np.int8, "-") == -126
    with pytest.raises(
        TypeError, match=re.escape("Unknown sign *, use either '+' or '-'")
    ):
        get_inf(np.int32, "*")

    with pytest.raises(
        ValueError,
        match=re.escape("Unknown dtype <class 'numpy.complex128'> for infinite bounds"),
    ):
        get_inf(np.complex_, "+")


def test_sample_mask():
    """Box cannot have a mask applied."""
    space = Box(0, 1)
    with pytest.raises(
        gym.error.Error,
        match=re.escape("Box.sample cannot be provided a mask, actual value: "),
    ):
        space.sample(mask=np.array([0, 1, 0], dtype=np.int8))