
Testing Observation Space Filtering Implementation in OpenAI Gym

This test suite validates the FilterObservation wrapper in OpenAI Gym, which restricts an environment's Dict observation space to a chosen subset of its keys. The tests verify that the wrapped environment exposes exactly the requested keys and that invalid filter_keys arguments raise the expected exceptions.

Test Coverage Overview

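The test suite covers the following behavior of the FilterObservation wrapper:

  • Filtering of observation keys in Dict observation spaces
  • Four combinations of observation keys and filter keys, including keeping a single key and keeping all keys
  • Handling of filter_keys=None, which keeps every key of the original observation space
  • A wrapped observation space that exposes exactly the filtered keys, and a reset() observation with one entry per filter key
  • Invalid filter_keys values (a plain string, False, or 1) raising ValueError or TypeError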
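The filtering these tests exercise can be illustrated on a plain Python dict. The sketch below uses a hypothetical filter_observation helper (not part of gym's API) and assumes, as the test cases do, that filter_keys=None keeps every key. It keeps the surviving entries in their original order.

```python
from typing import Any, Dict, Iterable, Optional


def filter_observation(
    observation: Dict[str, Any], filter_keys: Optional[Iterable[str]] = None
) -> Dict[str, Any]:
    """Hypothetical helper illustrating key filtering; not part of gym's API."""
    if filter_keys is None:
        # None means "keep every key", like the (("key1",), None) test case.
        return dict(observation)
    keep = set(filter_keys)
    # Preserve the original key order and drop entries not named in filter_keys.
    return {name: value for name, value in observation.items() if name in keep}


observation = {"key1": [0.1, -0.2], "key2": [0.5, 0.0]}
print(filter_observation(observation, ("key1",)))  # {'key1': [0.1, -0.2]}
print(filter_observation(observation))  # {'key1': [0.1, -0.2], 'key2': [0.5, 0.0]}
```

Validation of filter_keys is deliberately left out of this sketch; the error cases are a separate concern in the test suite.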

Implementation Analysis

The tests use pytest's @pytest.mark.parametrize decorator to run one test body over several input cases.

A FakeEnvironment class simulates a gym environment whose Dict observation space is built from a configurable tuple of keys. Valid filtering scenarios are checked with plain assertions on the wrapped observation space, while error conditions use pytest.raises, which checks both the exception type and a regular expression matched against the exception message.
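When match is given, pytest.raises runs re.search on the string form of the raised exception. The sketch below reproduces that check with the standard library alone, using the three patterns from ERROR_TEST_CASES. The two TypeError messages are produced by Python itself; the full ValueError message is an assumption, since the test module only fixes its prefix. The helper iteration_error is hypothetical.

```python
import re

# The three match patterns from ERROR_TEST_CASES.
VALUE_ERROR_PATTERN = "All the filter_keys must be included..*"
BOOL_PATTERN = "'bool' object is not iterable"
INT_PATTERN = "'int' object is not iterable"


def iteration_error(value) -> str:
    """Hypothetical helper: message of the TypeError raised by iterating value."""
    try:
        iter(value)
    except TypeError as exc:
        return str(exc)
    return ""  # value was iterable, so no error was raised


# pytest.raises(..., match=p) passes when re.search(p, str(exc)) finds a match.
assert re.search(BOOL_PATTERN, iteration_error(False))
assert re.search(INT_PATTERN, iteration_error(1))

# Assumed message text: only the prefix comes from the test module.
assumed_message = "All the filter_keys must be included in the original observation space."
assert re.search(VALUE_ERROR_PATTERN, assumed_message)
# "..*" needs at least one character after "included", so the bare prefix fails.
assert re.search(VALUE_ERROR_PATTERN, "All the filter_keys must be included") is None
```

In other words, the trailing "..*" in the ValueError pattern behaves like ".+": it only requires that the message continue past "included".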

Technical Details

  • Testing Framework: pytest
  • Key Dependencies: numpy, gym
  • Custom Components: FakeEnvironment, a subclass of gym.Env
  • Test Data: parametrized cases in FILTER_OBSERVATION_TEST_CASES (successful filtering) and ERROR_TEST_CASES (invalid arguments)

Best Practices Demonstrated

The test suite exemplifies several testing best practices:

  • Error cases checked against both the exception type and a message pattern
  • Use of parametrized tests to reduce code duplication
  • Clear separation of test cases into success and error scenarios
  • Type hints on the FakeEnvironment constructor and the reset() signature

openai/gym

tests/wrappers/test_filter_observation.py

from typing import Optional, Tuple

import numpy as np
import pytest

import gym
from gym import spaces
from gym.wrappers.filter_observation import FilterObservation


class FakeEnvironment(gym.Env):
    def __init__(
        self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",)
    ):
        self.observation_space = spaces.Dict(
            {
                name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
                for name in observation_keys
            }
        )
        self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        image_shape = (32, 32, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}

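    def step(self, action):
        del action  # The fake environment ignores actions.
        observation = self.observation_space.sample()
        # Use the 5-tuple step API, consistent with reset() returning (observation, info).
        reward, terminated, truncated, info = 0.0, False, False, {}
        return observation, reward, terminated, truncated, info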


FILTER_OBSERVATION_TEST_CASES = (
    (("key1", "key2"), ("key1",)),
    (("key1", "key2"), ("key1", "key2")),
    (("key1",), None),
    (("key1",), ("key1",)),
)

ERROR_TEST_CASES = (
    ("key", ValueError, "All the filter_keys must be included..*"),
    (False, TypeError, "'bool' object is not iterable"),
    (1, TypeError, "'int' object is not iterable"),
)


class TestFilterObservation:
    @pytest.mark.parametrize(
        "observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
    )
    def test_filter_observation(self, observation_keys, filter_keys):
        env = FakeEnvironment(observation_keys=observation_keys)

        # Make sure we are testing the right environment for the test.
        observation_space = env.observation_space
        assert isinstance(observation_space, spaces.Dict)

        wrapped_env = FilterObservation(env, filter_keys=filter_keys)

        assert isinstance(wrapped_env.observation_space, spaces.Dict)

        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        assert len(wrapped_env.observation_space.spaces) == len(filter_keys)
        assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)

        # Check that reset() returns an observation with one entry per filter key.
        observation, info = wrapped_env.reset()
        assert len(observation) == len(filter_keys)
        assert isinstance(info, dict)

    @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
    def test_raises_with_incorrect_arguments(
        self, filter_keys, error_type, error_match
    ):
        env = FakeEnvironment(observation_keys=("key1", "key2"))

        with pytest.raises(error_type, match=error_match):
            FilterObservation(env, filter_keys=filter_keys)
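The ERROR_TEST_CASES entries differ in an instructive way. A plain string such as "key" is iterable, so it is treated as the keys "k", "e", "y" and fails the membership check with a ValueError, whereas False and 1 are not iterable and fail with a TypeError before any key is checked. The following is a minimal sketch of such a validation step; validate_filter_keys and its exact error text are hypothetical, written only to mirror the behavior the tests expect.

```python
from typing import Iterable, Optional, Tuple


def validate_filter_keys(
    observation_keys: Tuple[str, ...], filter_keys: Optional[Iterable[str]]
) -> Tuple[str, ...]:
    """Hypothetical validation step mirroring what the tests expect."""
    if filter_keys is None:
        return tuple(observation_keys)
    # Iterating raises TypeError for False or 1. A str is iterable,
    # so "key" is checked character by character: "k", "e", "y".
    missing = {key for key in filter_keys if key not in observation_keys}
    if missing:
        raise ValueError(
            "All the filter_keys must be included in the original observation space. "
            f"Missing keys: {sorted(missing)}"
        )
    return tuple(filter_keys)


try:
    validate_filter_keys(("key1", "key2"), "key")
except ValueError as exc:
    print(exc)  # ends with: Missing keys: ['e', 'k', 'y']

try:
    validate_filter_keys(("key1", "key2"), False)
except TypeError as exc:
    print(exc)  # 'bool' object is not iterable
```

This is why the string case is paired with ValueError in ERROR_TEST_CASES while the bool and int cases are paired with TypeError.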