Back to Repositories

Testing Nested Dictionary Observation Wrappers in OpenAI Gym

This test suite validates the functionality of nested dictionary observation handling in OpenAI Gym environments, focusing on the FilterObservation and FlattenObservation wrappers. The tests ensure proper processing of complex observation spaces with nested dictionaries, tuples, and boxes.

Test Coverage Overview

The test suite provides comprehensive coverage for nested dictionary observations in Gym environments.

Key areas tested include:
  • Various nested dictionary structures with different shapes and types
  • Observation space flattening operations
  • Complex data type handling including Box, Dict, and Tuple spaces
  • Shape validation across different observation configurations

Implementation Analysis

The testing approach uses pytest’s parametrize feature to validate multiple test cases efficiently. The implementation follows a structured pattern using a FakeEnvironment class that simulates different observation space configurations.

Technical patterns include:
  • Parametrized test cases for different observation space structures
  • Wrapper composition with FilterObservation and FlattenObservation
  • Dynamic shape verification for flattened observations

Technical Details

Testing tools and configuration:
  • pytest framework for test organization
  • numpy for array operations and validation
  • gym.spaces components (Box, Dict, Tuple)
  • Custom FakeEnvironment class for controlled testing
  • Type hints and optional parameters for enhanced code clarity

Best Practices Demonstrated

The test suite exemplifies several testing best practices in Python and Gym environment testing.

Notable practices include:
  • Systematic test case organization using pytest fixtures
  • Comprehensive edge case coverage through parametrized tests
  • Clean separation of test environment setup and actual test logic
  • Strong type checking and space validation

openai/gym

tests/wrappers/test_nested_dict.py

            
"""Tests for the filter observation wrapper."""
from typing import Optional

import numpy as np
import pytest

import gym
from gym.spaces import Box, Dict, Tuple
from gym.wrappers import FilterObservation, FlattenObservation


class FakeEnvironment(gym.Env):
    def __init__(self, observation_space, render_mode=None):
        self.observation_space = observation_space
        self.obs_keys = self.observation_space.spaces.keys()
        self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        image_shape = (32, 32, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}

    def step(self, action):
        del action
        observation = self.observation_space.sample()
        reward, terminal, info = 0.0, False, {}
        return observation, reward, terminal, info


NESTED_DICT_TEST_CASES = (
    (
        Dict(
            {
                "key1": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
                "key2": Dict(
                    {
                        "subkey1": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
                        "subkey2": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
                    }
                ),
            }
        ),
        (6,),
    ),
    (
        Dict(
            {
                "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        (9,),
    ),
    (
        Dict(
            {
                "key1": Tuple(
                    (
                        Box(shape=(2,), low=-1, high=1, dtype=np.float32),
                        Box(shape=(2,), low=-1, high=1, dtype=np.float32),
                    )
                ),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        (7,),
    ),
    (
        Dict(
            {
                "key1": Tuple((Box(shape=(2,), low=-1, high=1, dtype=np.float32),)),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        (5,),
    ),
    (
        Dict(
            {
                "key1": Tuple(
                    (Dict({"key9": Box(shape=(2,), low=-1, high=1, dtype=np.float32)}),)
                ),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        (5,),
    ),
)


class TestNestedDictWrapper:
    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_size(self, observation_space, flat_shape):
        env = FakeEnvironment(observation_space=observation_space)

        # Make sure we are testing the right environment for the test.
        observation_space = env.observation_space
        assert isinstance(observation_space, Dict)

        wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
        assert wrapped_env.observation_space.shape == flat_shape

        assert wrapped_env.observation_space.dtype == np.float32

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_ravel(self, observation_space, flat_shape):
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
        obs, info = wrapped_env.reset()
        assert obs.shape == wrapped_env.observation_space.shape
        assert isinstance(info, dict)