Testing Nested Dictionary Observation Wrappers in OpenAI Gym
This test suite validates how the FilterObservation and FlattenObservation wrappers in OpenAI Gym handle nested dictionary observations. The tests check that Dict observation spaces containing nested Dict, Tuple, and Box spaces flatten to a float32 space of the expected shape, and that the observation returned by reset() has that same shape.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
openai/gym
tests/wrappers/test_nested_dict.py
"""Tests for the filter observation wrapper."""
from typing import Optional
import numpy as np
import pytest
import gym
from gym.spaces import Box, Dict, Tuple
from gym.wrappers import FilterObservation, FlattenObservation
class FakeEnvironment(gym.Env):
    """Stub environment whose observations are random samples of a given space.

    The observation space is supplied by the caller and must expose a
    ``spaces`` mapping (the constructor reads ``observation_space.spaces``).
    The action space is a single float32 value in ``[-1, 1]``.
    """

    def __init__(self, observation_space, render_mode=None):
        """Record the observation space, its top-level keys, and the render mode.

        Args:
            observation_space: Space with a ``spaces`` mapping; stored as-is.
            render_mode: Stored on the instance; not otherwise used here.
        """
        self.observation_space = observation_space
        # Only the top-level keys are captured; nested keys are not expanded.
        self.obs_keys = observation_space.spaces.keys()
        self.action_space = Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        """Return an all-zero ``uint8`` array of shape ``(32, 32, 3)``.

        ``mode`` is accepted but ignored.
        """
        return np.zeros((32, 32, 3), dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Forward ``seed`` to the base class, then sample a fresh observation.

        ``options`` is accepted but ignored.

        Returns:
            A ``(observation, info)`` pair where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        """Ignore ``action`` and return a freshly sampled observation.

        Returns:
            ``(observation, 0.0, False, {})`` -- a four-element tuple.
        """
        # NOTE(review): reset() returns (obs, info) but this returns four
        # elements with no separate truncation flag -- confirm this matches
        # the step API of the targeted gym version.
        del action
        return self.observation_space.sample(), 0.0, False, {}
def _box(shape):
    """Return a new float32 Box with bounds [-1, 1] and the given shape."""
    return Box(shape=shape, low=-1, high=1, dtype=np.float32)


# Each entry pairs an observation space with the shape expected once it has
# been flattened. The arithmetic in the comments is the element count of each
# leaf Box, in declaration order.
NESTED_DICT_TEST_CASES = (
    # Dict nested in a Dict: 2 + (2 + 2) = 6
    (
        Dict(
            {
                "key1": _box((2,)),
                "key2": Dict({"subkey1": _box((2,)), "subkey2": _box((2,))}),
            }
        ),
        (6,),
    ),
    # Flat Dict with a 2-D and a scalar Box: 6 + 1 + 2 = 9
    (
        Dict(
            {
                "key1": _box((2, 3)),
                "key2": _box(()),
                "key3": _box((2,)),
            }
        ),
        (9,),
    ),
    # Two-element Tuple inside a Dict: (2 + 2) + 1 + 2 = 7
    (
        Dict(
            {
                "key1": Tuple((_box((2,)), _box((2,)))),
                "key2": _box(()),
                "key3": _box((2,)),
            }
        ),
        (7,),
    ),
    # One-element Tuple inside a Dict: 2 + 1 + 2 = 5
    (
        Dict(
            {
                "key1": Tuple((_box((2,)),)),
                "key2": _box(()),
                "key3": _box((2,)),
            }
        ),
        (5,),
    ),
    # Dict inside a Tuple inside a Dict: 2 + 1 + 2 = 5
    (
        Dict(
            {
                "key1": Tuple((Dict({"key9": _box((2,))}),)),
                "key2": _box(()),
                "key3": _box((2,)),
            }
        ),
        (5,),
    ),
)
class TestNestedDictWrapper:
    """Checks FlattenObservation(FilterObservation(...)) on nested Dict spaces.

    Every test is parametrized over ``NESTED_DICT_TEST_CASES``, which pairs an
    observation space with the shape expected after flattening.
    """

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_size(self, observation_space, flat_shape):
        """The wrapped observation space has the expected shape and is float32."""
        env = FakeEnvironment(observation_space=observation_space)
        # Make sure we are testing the right environment for the test.
        observation_space = env.observation_space
        assert isinstance(observation_space, Dict)
        wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
        assert wrapped_env.observation_space.shape == flat_shape
        assert wrapped_env.observation_space.dtype == np.float32

    @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES)
    def test_nested_dicts_ravel(self, observation_space, flat_shape):
        """The observation from ``reset()`` is flat and matches the expected shape."""
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
        obs, info = wrapped_env.reset()
        assert obs.shape == wrapped_env.observation_space.shape
        # Compare against the parametrized expectation as well, so this test
        # does not rely solely on the wrapper's own space being correct.
        assert obs.shape == flat_shape
        assert isinstance(info, dict)