Testing Observation Space Filtering Implementation in OpenAI Gym
This test suite validates the `FilterObservation` wrapper in OpenAI Gym, which restricts an environment's `Dict` observation space to a chosen subset of keys. The tests verify that both the filtered observation space and the observations returned by `reset` contain exactly the requested keys, and that invalid `filter_keys` arguments raise the expected errors.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
openai/gym
tests/wrappers/test_filter_observation.py
from typing import Optional, Tuple
import numpy as np
import pytest
import gym
from gym import spaces
from gym.wrappers.filter_observation import FilterObservation
class FakeEnvironment(gym.Env):
    """Minimal ``gym.Env`` with a ``Dict`` observation space, used to test wrappers.

    Every name in ``observation_keys`` maps to a ``Box(low=-1, high=1,
    shape=(2,), dtype=float32)`` subspace. Observations are random samples
    drawn from that ``Dict`` space.
    """

    def __init__(
        self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",)
    ):
        self.observation_space = spaces.Dict(
            {
                name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
                for name in observation_keys
            }
        )
        self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        """Return a blank 32x32 RGB ``uint8`` frame; ``mode`` is accepted but ignored."""
        image_shape = (32, 32, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Return ``(observation, info)`` with a random observation and empty info.

        When ``seed`` is given, the observation space is seeded as well.
        ``Env.reset`` only seeds ``self.np_random``, while observations here
        come from ``observation_space.sample()``, so without this the seed
        would not make the returned observation reproducible.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.observation_space.seed(seed)
        observation = self.observation_space.sample()
        return observation, {}

    def step(self, action):
        """Ignore ``action``; return a random observation with zero reward.

        Returns the five-tuple ``(observation, reward, terminated, truncated,
        info)``. This matches the ``(observation, info)``-style ``reset`` above,
        so wrappers that unpack the five-value step result can drive this env.
        """
        del action  # The fake dynamics do not depend on the action.
        observation = self.observation_space.sample()
        reward, terminated, truncated, info = 0.0, False, False, {}
        return observation, reward, terminated, truncated, info
# (observation_keys, filter_keys) pairs for ``test_filter_observation``.
# ``filter_keys=None`` means "keep every observation key".
FILTER_OBSERVATION_TEST_CASES = (
    (("key1", "key2"), ("key1",)),
    (("key1", "key2"), ("key2",)),  # keep only a non-leading key
    (("key1", "key2"), ("key1", "key2")),
    (("key1",), None),
    (("key1",), ("key1",)),
)
# (filter_keys, expected exception type, regex passed to ``pytest.raises(match=...)``).
# NOTE(review): a str is iterable, so "key" is presumably checked character by
# character ("k", "e", "y"); none of those exist in the env's observation keys.
ERROR_TEST_CASES = (
    ("key", ValueError, r"All the filter_keys must be included.*"),
    (False, TypeError, r"'bool' object is not iterable"),
    (1, TypeError, r"'int' object is not iterable"),
)
class TestFilterObservation:
    """Tests for the ``FilterObservation`` wrapper."""

    @pytest.mark.parametrize(
        "observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
    )
    def test_filter_observation(self, observation_keys, filter_keys):
        """The wrapped space and the reset observation keep exactly ``filter_keys``.

        ``filter_keys=None`` is expected to keep every original observation key.
        """
        env = FakeEnvironment(observation_keys=observation_keys)

        # Make sure we are testing the right environment for the test.
        observation_space = env.observation_space
        assert isinstance(observation_space, spaces.Dict)

        wrapped_env = FilterObservation(env, filter_keys=filter_keys)

        assert isinstance(wrapped_env.observation_space, spaces.Dict)

        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        assert len(wrapped_env.observation_space.spaces) == len(filter_keys)
        assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)

        # Check that the filtered observation agrees with the filtered space:
        # the same key set (not merely the same count) and values inside it.
        observation, info = wrapped_env.reset()
        assert len(observation) == len(filter_keys)
        assert set(observation.keys()) == set(filter_keys)
        assert observation in wrapped_env.observation_space
        assert isinstance(info, dict)

    @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
    def test_raises_with_incorrect_arguments(
        self, filter_keys, error_type, error_match
    ):
        """Invalid ``filter_keys`` make the wrapper constructor raise ``error_type``."""
        env = FakeEnvironment(observation_keys=("key1", "key2"))

        with pytest.raises(error_type, match=error_match):
            FilterObservation(env, filter_keys=filter_keys)