Support Gym 0.26.0 #205

Merged: 34 commits, Oct 18, 2022

Commits (34)
fff2c75 Support new gym API (Markus28, Oct 12, 2022)
d421e1d Remove gym restriction (Markus28, Oct 12, 2022)
5305767 Fixed some tests (Markus28, Oct 12, 2022)
3393172 Require packaging (Markus28, Oct 12, 2022)
7946014 Updated docs (Oct 13, 2022)
8cd2bc6 Formatting and fixing tests (Oct 13, 2022)
1e7274f Add requirement (Markus28, Oct 14, 2022)
e82a527 Another dependency (Markus28, Oct 14, 2022)
92a5a57 Another dependency (Markus28, Oct 14, 2022)
19795e7 Fixed some tests hopefully (Markus28, Oct 14, 2022)
44ca94c Fixed build files, again (Markus28, Oct 14, 2022)
ad77457 Removed restriction on Gym version (Markus28, Oct 15, 2022)
74b9cf0 fix lint (Trinkle23897, Oct 17, 2022)
c8031ea fix doc (Trinkle23897, Oct 17, 2022)
7f58599 upgrade dependency (Trinkle23897, Oct 17, 2022)
e03d494 Removed mjc-mwe (Oct 17, 2022)
1020616 Formatting (Oct 17, 2022)
cd54e0f Added assert for gym version (Oct 17, 2022)
f0c1bba Removed accidental paste (Oct 17, 2022)
de7ee2e Added gym version asserts (Oct 17, 2022)
ad9cf46 Always return info if new gym API. Raise exception in make in case of… (Oct 17, 2022)
d1afb27 Fixed tests, linting (Oct 17, 2022)
648a55a Fixed a test, linting (Oct 17, 2022)
53182ea Fixed two more tests (Oct 17, 2022)
bef3f6d Bugfix (Oct 17, 2022)
94c36e2 Added MuJoCo to LD_LIBRARY_PATH for CI (Oct 17, 2022)
e3980a9 Hopefully fixes CI (Oct 17, 2022)
b6e8431 Try to reset gym MuJoCo environment (Oct 17, 2022)
3fb850d Try to reset gym MuJoCo environment (Oct 17, 2022)
bf5232f Hack to use private method (Oct 17, 2022)
22d8dd3 test (Trinkle23897, Oct 18, 2022)
eaa170e test (Trinkle23897, Oct 18, 2022)
8fd84ca revert (Trinkle23897, Oct 18, 2022)
74d0e83 update setup.cfg (Trinkle23897, Oct 18, 2022)

1 change: 1 addition & 0 deletions .bazelrc
@@ -1,6 +1,7 @@
build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm
build --action_env=BAZEL_LINKOPTS=-static-libgcc
build --action_env=CUDA_DIR=/usr/local/cuda
build --action_env=LD_LIBRARY_PATH=/home/ubuntu/.mujoco/mujoco210/bin
build --incompatible_strict_action_env --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --client_env=BAZEL_CXXOPTS=-std=c++17
build:debug --cxxopt=-DENVPOOL_TEST --compilation_mode=dbg -s
build:test --cxxopt=-DENVPOOL_TEST --copt=-g0 --copt=-O3 --copt=-DNDEBUG --copt=-msse --copt=-msse2 --copt=-mmmx
1 change: 1 addition & 0 deletions benchmark/requirements.txt
@@ -7,3 +7,4 @@ mujoco_py==2.1.2.14
tqdm
opencv-python-headless
dm_control==1.0.3.post1
packaging
4 changes: 2 additions & 2 deletions docker/release.dockerfile
@@ -36,7 +36,7 @@ RUN go install github.com/bazelbuild/bazelisk@latest && ln -sf $HOME/go/bin/baze

# install big wheels

RUN for i in 7 8 9; do ln -sf /usr/bin/python3.$i /usr/bin/python3; pip3 install torch opencv-python-headless; done
RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; pip3 install torch opencv-python-headless; done

RUN bazel version

@@ -45,4 +45,4 @@ COPY . .

# compile and test release wheels

RUN for i in 7 8 9; do ln -sf /usr/bin/python3.$i /usr/bin/python3; make pypi-wheel BAZELOPT="--remote_cache=http://bazel-cache.sail:8080"; pip3 install wheelhouse/*cp3$i*.whl; rm dist/*.whl; make release-test; done
RUN for i in 7 8 9 10; do ln -sf /usr/bin/python3.$i /usr/bin/python3; make pypi-wheel BAZELOPT="--remote_cache=http://bazel-cache.sail:8080"; pip3 install wheelhouse/*cp3$i*.whl; rm dist/*.whl; make release-test; done
24 changes: 13 additions & 11 deletions docs/content/python_interface.rst
@@ -31,7 +31,8 @@ batched environments:
``gym.Env``, while some environments may not have such an option;
* ``gym_reset_return_info (bool)``: whether to return a tuple of
``(obs, info)`` instead of only ``obs`` when calling reset in ``gym.Env``,
default to ``False``; this option is to adapt the newest version of gym's
defaults to ``False`` if you are using Gym<0.26.0, otherwise it defaults
to ``True``; this option is to adapt the newest version of gym's
interface;
* other configurations such as ``img_height`` / ``img_width`` / ``stack_num``
/ ``frame_skip`` / ``noop_max`` in Atari env, ``reward_metric`` /
@@ -115,16 +116,17 @@ third case, use ``env.step(action)`` where action is a dictionary.
Data Output Format
------------------

+----------+------------------------------------------------------------------+------------------------------------------------------------------+
| function | gym | dm |
| | | |
+==========+==================================================================+==================================================================+
| reset | | env_id -> obs array (single observation) | env_id -> TimeStep(FIRST, obs|info|env_id, rew=0, discount or 1) |
| | | or an obs dict (multi observation) | |
| | | or (obs, info) tuple (when ``gym_reset_return_info`` == True) | |
+----------+------------------------------------------------------------------+------------------------------------------------------------------+
| step | (obs, rew, done, info|env_id) | TimeStep(StepType, obs|info|env_id, rew, discount or 1 - done) |
+----------+------------------------------------------------------------------+------------------------------------------------------------------+
+----------+----------------------------------------------------------------------+------------------------------------------------------------------+
| function | gym | dm |
| | | |
+==========+======================================================================+==================================================================+
| reset | | env_id -> obs array (single observation) | env_id -> TimeStep(FIRST, obs|info|env_id, rew=0, discount or 1) |
| | | or an obs dict (multi observation) | |
| | | or (obs, info) tuple (when ``gym_reset_return_info`` == True) | |
+----------+----------------------------------------------------------------------+------------------------------------------------------------------+
| step | (obs, rew, done, info|env_id) or | TimeStep(StepType, obs|info|env_id, rew, discount or 1 - done) |
| | (obs, rew, terminated, truncated, info|env_id) (when Gym >= 0.26.0) | |
+----------+----------------------------------------------------------------------+------------------------------------------------------------------+

Note: ``gym.reset()`` doesn't support async step setting because it cannot get
``env_id`` from ``reset()`` function, so it's better to use low-level APIs such
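
To make the new gym-side format above concrete, here is a minimal sketch of using an envpool gym environment after this PR, assuming Gym >= 0.26.0 is installed (the ``Pong-v5`` task and its 6-action space follow the Atari tests below; the exact ``make`` arguments are illustrative):

import envpool
import numpy as np

# With Gym >= 0.26.0 installed, gym_reset_return_info defaults to True,
# so reset() returns an (obs, info) tuple.
env = envpool.make("Pong-v5", env_type="gym", num_envs=4)
obs, info = env.reset()

# step() returns the five-tuple (obs, rew, terminated, truncated, info)
# instead of the old four-tuple (obs, rew, done, info).
act = np.random.randint(6, size=4)
obs, rew, terminated, truncated, info = env.step(act)

# Code that still wants a single done flag combines the two,
# as the updated tests in this PR do.
done = np.logical_or(terminated, truncated)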
2 changes: 1 addition & 1 deletion envpool/__init__.py
@@ -23,7 +23,7 @@
register,
)

__version__ = "0.6.4"
__version__ = "0.6.5"
__all__ = [
"register",
"make",
43 changes: 29 additions & 14 deletions envpool/atari/api_test.py
@@ -20,6 +20,7 @@
import numpy as np
from absl import logging
from absl.testing import absltest
from packaging import version

from envpool.atari import AtariDMEnvPool, AtariEnvSpec, AtariGymEnvPool

@@ -250,63 +251,77 @@ def test_lowlevel_step(self) -> None:
self.assertTrue(isinstance(env, gym.Env))
logging.info(env)
env.async_reset()
obs, rew, done, info = env.recv()
obs, rew, terminated, truncated, info = env.recv()
done = np.logical_or(terminated, truncated)
# check shape
self.assertIsInstance(obs, np.ndarray)
self.assertEqual(obs.dtype, np.uint8)
np.testing.assert_allclose(rew.shape, (num_envs,))
self.assertEqual(rew.dtype, np.float32)
np.testing.assert_allclose(done.shape, (num_envs,))
self.assertEqual(done.dtype, np.bool_)
self.assertEqual(terminated.dtype, np.bool_)
self.assertEqual(truncated.dtype, np.bool_)
self.assertIsInstance(info, dict)
self.assertEqual(len(info), 7)
self.assertEqual(len(info), 6)
self.assertEqual(info["env_id"].dtype, np.int32)
self.assertEqual(info["lives"].dtype, np.int32)
self.assertEqual(info["players"]["env_id"].dtype, np.int32)
self.assertEqual(info["TimeLimit.truncated"].dtype, np.bool_)
np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
np.testing.assert_allclose(info["lives"].shape, (num_envs,))
np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))
np.testing.assert_allclose(info["TimeLimit.truncated"].shape, (num_envs,))
np.testing.assert_allclose(truncated.shape, (num_envs,))
while not np.any(done):
env.send(np.random.randint(6, size=num_envs))
obs, rew, done, info = env.recv()
obs, rew, terminated, truncated, info = env.recv()
done = np.logical_or(terminated, truncated)
env.send(np.random.randint(6, size=num_envs))
obs1, rew1, done1, info1 = env.recv()
obs1, rew1, terminated1, truncated1, info1 = env.recv()
done1 = np.logical_or(terminated1, truncated1)
index = np.where(done)[0]
self.assertTrue(np.all(~done1[index]))

def test_highlevel_step(self) -> None:
assert version.parse(gym.__version__) >= version.parse("0.26.0")
num_envs = 4
config = AtariEnvSpec.gen_config(task="pong", num_envs=num_envs)
spec = AtariEnvSpec(config)
env = AtariGymEnvPool(spec)
self.assertTrue(isinstance(env, gym.Env))
logging.info(env)
obs = env.reset()
obs, _ = env.reset()
# check shape
self.assertIsInstance(obs, np.ndarray)
self.assertEqual(obs.dtype, np.uint8) # type: ignore
obs, rew, done, info = env.step(np.random.randint(6, size=num_envs))
self.assertEqual(obs.dtype, np.uint8)
obs, rew, terminated, truncated, info = env.step(
np.random.randint(6, size=num_envs)
)
done = np.logical_or(terminated, truncated)
self.assertIsInstance(obs, np.ndarray)
self.assertEqual(obs.dtype, np.uint8)
np.testing.assert_allclose(rew.shape, (num_envs,))
self.assertEqual(rew.dtype, np.float32)
np.testing.assert_allclose(done.shape, (num_envs,))
self.assertEqual(done.dtype, np.bool_)
self.assertIsInstance(info, dict)
self.assertEqual(len(info), 7)
self.assertEqual(len(info), 6)
self.assertEqual(info["env_id"].dtype, np.int32)
self.assertEqual(info["lives"].dtype, np.int32)
self.assertEqual(info["players"]["env_id"].dtype, np.int32)
self.assertEqual(info["TimeLimit.truncated"].dtype, np.bool_)
self.assertEqual(truncated.dtype, np.bool_)
np.testing.assert_allclose(info["env_id"], np.arange(num_envs))
np.testing.assert_allclose(info["lives"].shape, (num_envs,))
np.testing.assert_allclose(info["players"]["env_id"].shape, (num_envs,))
np.testing.assert_allclose(info["TimeLimit.truncated"].shape, (num_envs,))
np.testing.assert_allclose(truncated.shape, (num_envs,))
while not np.any(done):
obs, rew, done, info = env.step(np.random.randint(6, size=num_envs))
obs1, rew1, done1, info1 = env.step(np.random.randint(6, size=num_envs))
obs, rew, terminated, truncated, info = env.step(
np.random.randint(6, size=num_envs)
)
done = np.logical_or(terminated, truncated)
obs1, rew1, terminated1, truncated1, info1 = env.step(
np.random.randint(6, size=num_envs)
)
done1 = np.logical_or(terminated1, truncated1)
index = np.where(done)[0]
self.assertTrue(np.all(~done1[index]))
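
The ``version.parse`` assert above pins this test to Gym >= 0.26.0, using the ``packaging`` dependency added by this PR. Downstream code that has to run against both the old and the new API can branch on the same check; a minimal sketch (the ``step_compat`` helper is hypothetical, not part of this PR):

import gym
import numpy as np
from packaging import version

NEW_GYM_API = version.parse(gym.__version__) >= version.parse("0.26.0")

def step_compat(env, action):
  # Hypothetical helper: normalize step() output to the old four-tuple
  # regardless of the installed Gym version.
  if NEW_GYM_API:
    obs, rew, terminated, truncated, info = env.step(action)
    return obs, rew, np.logical_or(terminated, truncated), info
  return env.step(action)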

38 changes: 22 additions & 16 deletions envpool/atari/atari_envpool_test.py
@@ -68,7 +68,7 @@ def test_align(self) -> None:
spec = AtariEnvSpec(config)
env0 = AtariGymEnvPool(spec)
env1 = AtariDMEnvPool(spec)
obs0 = env0.reset()
obs0, _ = env0.reset()
obs1 = env1.reset().observation.obs # type: ignore
np.testing.assert_allclose(obs0, obs1)
for _ in range(1000):
Expand Down Expand Up @@ -96,7 +96,10 @@ def test_reset_life(self) -> None:
# no life in this game
continue
for _ in range(10000):
_, _, done, info = env.step(np.random.randint(0, action_num, 1))
_, _, terminated, truncated, info = env.step(
np.random.randint(0, action_num, 1)
)
done = np.logical_or(terminated, truncated)
if info["lives"][0] == 0:
break
else:
@@ -106,7 +109,7 @@
continue
# for normal atari (e.g., breakout)
# take an additional step after all lives are exhausted
_, _, next_done, next_info = env.step(
_, _, next_terminated, next_truncated, next_info = env.step(
np.random.randint(0, action_num, 1)
)
if done[0] and next_info["lives"][0] > 0:
@@ -116,8 +119,11 @@
self.assertFalse(info["terminated"][0])
while not done[0]:
self.assertFalse(info["terminated"][0])
_, _, done, info = env.step(np.random.randint(0, action_num, 1))
_, _, next_done, next_info = env.step(
_, _, terminated, truncated, info = env.step(
np.random.randint(0, action_num, 1)
)
done = np.logical_or(terminated, truncated)
_, _, next_terminated, next_truncated, next_info = env.step(
np.random.randint(0, action_num, 1)
)
self.assertTrue(next_info["lives"][0] > 0)
@@ -137,21 +143,21 @@ def test_partial_step(self) -> None:
partial_ids = [np.arange(num_envs)[::2], np.arange(num_envs)[1::2]]
env.step(np.zeros(len(partial_ids[1]), dtype=int), env_id=partial_ids[1])
for _ in range(max_episode_steps - 2):
info = env.step(
_, _, _, truncated, info = env.step(
np.zeros(num_envs, dtype=int), env_id=np.arange(num_envs)
)[-1]
assert np.all(~info["TimeLimit.truncated"])
info = env.step(
)
assert np.all(~truncated)
_, _, _, truncated, info = env.step(
np.zeros(num_envs, dtype=int), env_id=np.arange(num_envs)
)[-1]
)
env_id = np.array(info["env_id"])
done_id = np.array(sorted(env_id[info["TimeLimit.truncated"]]))
done_id = np.array(sorted(env_id[truncated]))
assert np.all(done_id == partial_ids[1])
info = env.step(
_, _, _, truncated, info = env.step(
np.zeros(len(partial_ids[0]), dtype=int),
env_id=partial_ids[0],
)[-1]
assert np.all(info["TimeLimit.truncated"])
)
assert np.all(truncated)

def test_xla_api(self) -> None:
num_envs = 10
@@ -216,15 +222,15 @@ def test_no_gray_scale(self) -> None:
spec = AtariEnvSpec(config)
env = AtariGymEnvPool(spec)
self.assertTrue(env.observation_space.shape, ref_shape)
obs = env.reset()
obs, _ = env.reset()
self.assertTrue(obs.shape, ref_shape)
config = AtariEnvSpec.gen_config(
task="breakout", gray_scale=False, img_height=210, img_width=160
)
spec = AtariEnvSpec(config)
env = AtariGymEnvPool(spec)
self.assertTrue(env.observation_space.shape, raw_shape)
obs1 = env.reset()
obs1, _ = env.reset()
self.assertTrue(obs1.shape, raw_shape)
for i in range(0, 12, 3):
obs_ = cv2.resize(
5 changes: 3 additions & 2 deletions envpool/atari/atari_pretrain_test.py
@@ -59,13 +59,14 @@ def eval_qrdqn(
policy.eval()
ids = np.arange(num_envs)
reward = np.zeros(num_envs)
obs = env.reset()
obs, _ = env.reset()
for _ in range(25000):
if np.random.rand() < 5e-3:
act = np.random.randint(action_shape, size=len(ids))
else:
act = policy(Batch(obs=obs, info={})).act
obs, rew, done, info = env.step(act, ids)
obs, rew, terminated, truncated, info = env.step(act, ids)
done = np.logical_or(terminated, truncated)
ids = np.asarray(info["env_id"])
reward[ids] += rew
obs = obs[~done]
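
The ``eval_qrdqn`` loop above leans on envpool's async batching: ``info["env_id"]`` says which environments the returned batch belongs to, and finished environments are masked out before the next step. A condensed sketch of that pattern under the new API (the ``ids`` masking after ``obs = obs[~done]`` is collapsed in this diff, so this assumes it mirrors the ``obs`` masking; the random policy is a stand-in for the QRDQN policy):

import envpool
import numpy as np

num_envs, num_actions = 4, 6  # Pong's action space, as in the tests above
env = envpool.make("Pong-v5", env_type="gym", num_envs=num_envs)

ids = np.arange(num_envs)  # environments we are still driving
reward = np.zeros(num_envs)
obs, _ = env.reset()
while len(ids) > 0:
  act = np.random.randint(num_actions, size=len(ids))  # stand-in policy
  obs, rew, terminated, truncated, info = env.step(act, ids)
  done = np.logical_or(terminated, truncated)
  ids = np.asarray(info["env_id"])  # batch rows map back to these env ids
  reward[ids] += rew
  obs, ids = obs[~done], ids[~done]  # drop finished environments
print(reward)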
10 changes: 6 additions & 4 deletions envpool/box2d/box2d_correctness_test.py
@@ -109,13 +109,14 @@ def solve_lunar_lander(self, num_envs: int, continuous: bool) -> None:
for _ in range(2):
env_id = np.arange(num_envs)
done = np.array([False] * num_envs)
obs = env.reset(env_id)
obs, _ = env.reset(env_id)
rewards = np.zeros(num_envs)
while not np.all(done):
action = np.array(
[self.heuristic_lunar_lander_policy(s, continuous) for s in obs]
)
obs, rew, done, info = env.step(action, env_id)
obs, rew, terminated, truncated, info = env.step(action, env_id)
done = np.logical_or(terminated, truncated)
env_id = info["env_id"]
rewards[env_id] += rew
obs = obs[~done]
@@ -228,11 +229,12 @@ def solve_bipedal_walker(
)
env_id = np.arange(num_envs)
done = np.array([False] * num_envs)
obs = env.reset(env_id)
obs, _ = env.reset(env_id)
rewards = np.zeros(num_envs)
action = np.zeros([num_envs, 4])
for _ in range(max_episode_steps):
obs, rew, done, info = env.step(action, env_id)
obs, rew, terminated, truncated, info = env.step(action, env_id)
done = np.logical_or(terminated, truncated)
if render:
self.render_bpw(info)
env_id = info["env_id"]
10 changes: 7 additions & 3 deletions envpool/classic_control/classic_control_test.py
@@ -82,11 +82,15 @@ def run_align_check(self, env0: gym.Env, env1: Any, reset_fn: Any) -> None:
d0 = False
while not d0:
a = env0.action_space.sample()
o0, r0, d0, _ = env0.step(a)
o1, r1, d1, _ = env1.step(np.array([a]), np.array([0]))
o0, r0, term0, trunc0, _ = env0.step(a)
d0 = np.logical_or(term0, trunc0)
o1, r1, term1, trunc1, _ = env1.step(np.array([a]), np.array([0]))
d1 = np.logical_or(term1, trunc1)
np.testing.assert_allclose(o0, o1[0], atol=1e-4)
np.testing.assert_allclose(r0, r1[0])
np.testing.assert_allclose(d0, d1[0])
np.testing.assert_allclose(term0, term1[0])
np.testing.assert_allclose(trunc0, trunc1[0])

def test_cartpole(self) -> None:
env0 = gym.make("CartPole-v1")
@@ -109,7 +113,7 @@ def test_mountain_car(self) -> None:
@no_type_check
def reset_fn(env0: gym.Env, env1: Any) -> None:
env0.reset()
obs = env1.reset()
obs, _ = env1.reset()
env0.unwrapped.state = obs[0]

env0 = gym.make("MountainCar-v0")
1 change: 0 additions & 1 deletion envpool/mujoco/BUILD
@@ -158,7 +158,6 @@ py_test(
requirement("absl-py"),
requirement("gym"),
requirement("mujoco"),
requirement("mjc_mwe"),
],
)
