Skip to content

Commit

Permalink
Internal change
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 702880644
Change-Id: Ibbb6a32ab1b0b5de6223f4044ca89c0c6c5bb7cc
  • Loading branch information
Brax Team authored and btaba committed Dec 4, 2024
1 parent fb553a7 commit 18d6a0a
Show file tree
Hide file tree
Showing 63 changed files with 1,027 additions and 595 deletions.
1 change: 1 addition & 0 deletions brax/actuator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

def _actuator_step(pipeline, sys, q, qd, act, dt, n):
sys = sys.tree_replace({'opt.timestep': dt})

def f(state, _):
return jax.jit(pipeline.step)(sys, state, act), None

Expand Down
4 changes: 2 additions & 2 deletions brax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ class Inertia(Base):
"""Angular inertia, mass, and center of mass location.
Attributes:
transform: transform for the inertial frame relative to the link frame
(i.e. center of mass position and orientation)
transform: transform for the inertial frame relative to the link frame (i.e.
center of mass position and orientation)
i: (3, 3) inertia matrix about a point P
mass: scalar mass
"""
Expand Down
1 change: 1 addition & 0 deletions brax/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Helper functions for physics calculations in maximal coordinates."""

# pylint:disable=g-multiple-import
from typing import Tuple

Expand Down
1 change: 1 addition & 0 deletions brax/com_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Tests for com."""

# pylint:disable=g-multiple-import
from absl.testing import absltest
from brax import com
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_velocity = (
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_velocity = (
Expand Down
3 changes: 2 additions & 1 deletion brax/envs/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ def _get_obs(
com_velocity = jp.hstack([com_vel, com_ang])

qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd)
self.sys, action, pipeline_state.q, pipeline_state.qd
)

# external_contact_forces are excluded
return jp.concatenate([
Expand Down
3 changes: 2 additions & 1 deletion brax/envs/humanoidstandup.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def _get_obs(
com_velocity = jp.hstack([com_vel, com_ang])

qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd)
self.sys, action, pipeline_state.q, pipeline_state.qd
)

# external_contact_forces are excluded
return jp.concatenate([
Expand Down
19 changes: 8 additions & 11 deletions brax/envs/inverted_double_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ class InvertedDoublePendulum(PipelineEnv):

def __init__(self, backend='generalized', **kwargs):
path = (
epath.resource_path('brax')
/ 'envs/assets/inverted_double_pendulum.xml'
epath.resource_path('brax') / 'envs/assets/inverted_double_pendulum.xml'
)
sys = mjcf.load(path)

Expand Down Expand Up @@ -176,12 +175,10 @@ def action_size(self):

def _get_obs(self, pipeline_sate: base.State) -> jax.Array:
"""Observe cartpole body position and velocities."""
return jp.concatenate(
[
pipeline_sate.q[:1], # cart x pos
jp.sin(pipeline_sate.q[1:]),
jp.cos(pipeline_sate.q[1:]),
jp.clip(pipeline_sate.qd, -10, 10),
# qfrc_constraint is not added
]
)
return jp.concatenate([
pipeline_sate.q[:1], # cart x pos
jp.sin(pipeline_sate.q[1:]),
jp.cos(pipeline_sate.q[1:]),
jp.clip(pipeline_sate.qd, -10, 10),
# qfrc_constraint is not added
])
22 changes: 13 additions & 9 deletions brax/envs/swimmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ class Swimmer(PipelineEnv):
# pyformat: enable


def __init__(self,
forward_reward_weight=1.0,
ctrl_cost_weight=1e-4,
reset_noise_scale=0.1,
exclude_current_positions_from_observation=True,
backend='generalized',
**kwargs):
def __init__(
self,
forward_reward_weight=1.0,
ctrl_cost_weight=1e-4,
reset_noise_scale=0.1,
exclude_current_positions_from_observation=True,
backend='generalized',
**kwargs,
):
path = epath.resource_path('brax') / 'envs/assets/swimmer.xml'
sys = mjcf.load(path)

Expand All @@ -130,7 +132,8 @@ def __init__(self,
self._ctrl_cost_weight = ctrl_cost_weight
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation)
exclude_current_positions_from_observation
)

def reset(self, rng: jax.Array) -> State:
rng, rng1, rng2 = jax.random.split(rng, 3)
Expand All @@ -157,7 +160,8 @@ def step(self, state: State, action: jax.Array) -> State:

if pipeline_state0 is None:
raise AssertionError(
'Cannot compute rewards with pipeline_state0 as Nonetype.')
'Cannot compute rewards with pipeline_state0 as Nonetype.'
)

xy_position = pipeline_state.q[:2]

Expand Down
2 changes: 1 addition & 1 deletion brax/envs/walker2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
assert pipeline_state0 is not None
assert pipeline_state0 is not None
pipeline_state = self.pipeline_step(pipeline_state0, action)

x_velocity = (
Expand Down
51 changes: 30 additions & 21 deletions brax/envs/wrappers/dm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Wrappers to convert brax envs to DM Env envs."""

from typing import Optional

from brax.envs.base import PipelineEnv
Expand All @@ -27,10 +28,9 @@
class DmEnvWrapper(dm_env.Environment):
"""A wrapper that converts Brax Env to one that follows Dm Env API."""

def __init__(self,
env: PipelineEnv,
seed: int = 0,
backend: Optional[str] = None):
def __init__(
self, env: PipelineEnv, seed: int = 0, backend: Optional[str] = None
):
self._env = env
self.seed(seed)
self.backend = backend
Expand All @@ -40,25 +40,32 @@ def __init__(self,
self._observation_spec = self._env.observation_spec()
else:
obs_high = jp.inf * jp.ones(self._env.observation_size, dtype='float32')
self._observation_spec = specs.BoundedArray((self._env.observation_size,),
minimum=-obs_high,
maximum=obs_high,
dtype='float32',
name='observation')
self._observation_spec = specs.BoundedArray(
(self._env.observation_size,),
minimum=-obs_high,
maximum=obs_high,
dtype='float32',
name='observation',
)

if hasattr(self._env, 'action_spec'):
self._action_spec = self._env.action_spec()
else:
action = jax.tree.map(np.array, self._env.sys.actuator.ctrl_range)
self._action_spec = specs.BoundedArray((self._env.action_size,),
minimum=action[:, 0],
maximum=action[:, 1],
dtype='float32',
name='action')

self._reward_spec = specs.Array(shape=(), dtype=jp.dtype('float32'), name='reward')
self._action_spec = specs.BoundedArray(
(self._env.action_size,),
minimum=action[:, 0],
maximum=action[:, 1],
dtype='float32',
name='action',
)

self._reward_spec = specs.Array(
shape=(), dtype=jp.dtype('float32'), name='reward'
)
self._discount_spec = specs.BoundedArray(
shape=(), dtype='float32', minimum=0., maximum=1., name='discount')
shape=(), dtype='float32', minimum=0.0, maximum=1.0, name='discount'
)
if hasattr(self._env, 'discount_spec'):
self._discount_spec = self._env.discount_spec()

Expand All @@ -81,17 +88,19 @@ def reset(self):
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=None,
discount=jp.float32(1.),
observation=obs)
discount=jp.float32(1.0),
observation=obs,
)

def step(self, action):
self._state, obs, reward, done, info = self._step(self._state, action)
del info
return dm_env.TimeStep(
step_type=dm_env.StepType.MID if not done else dm_env.StepType.LAST,
reward=reward,
discount=jp.float32(1.),
observation=obs)
discount=jp.float32(1.0),
observation=obs,
)

def seed(self, seed: int = 0):
self._key = jax.random.PRNGKey(seed)
Expand Down
6 changes: 4 additions & 2 deletions brax/envs/wrappers/dm_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ def test_action_space(self):
base_env = envs.create('pusher')
env = dm_env.DmEnvWrapper(base_env)
np.testing.assert_array_equal(
env.action_spec().minimum, base_env.sys.actuator.ctrl_range[:, 0])
env.action_spec().minimum, base_env.sys.actuator.ctrl_range[:, 0]
)
np.testing.assert_array_equal(
env.action_spec().maximum, base_env.sys.actuator.ctrl_range[:, 1])
env.action_spec().maximum, base_env.sys.actuator.ctrl_range[:, 1]
)


if __name__ == '__main__':
Expand Down
19 changes: 9 additions & 10 deletions brax/envs/wrappers/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Wrappers to convert brax envs to gym envs."""

from typing import ClassVar, Optional

from brax.envs.base import PipelineEnv
Expand All @@ -31,14 +32,13 @@ class GymWrapper(gym.Env):
# `_reset` as signs of a deprecated gym Env API.
_gym_disable_underscore_compat: ClassVar[bool] = True

def __init__(self,
env: PipelineEnv,
seed: int = 0,
backend: Optional[str] = None):
def __init__(
self, env: PipelineEnv, seed: int = 0, backend: Optional[str] = None
):
self._env = env
self.metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 1 / self._env.dt
'video.frames_per_second': 1 / self._env.dt,
}
self.seed(seed)
self.backend = backend
Expand Down Expand Up @@ -94,14 +94,13 @@ class VectorGymWrapper(gym.vector.VectorEnv):
# `_reset` as signs of a deprecated gym Env API.
_gym_disable_underscore_compat: ClassVar[bool] = True

def __init__(self,
env: PipelineEnv,
seed: int = 0,
backend: Optional[str] = None):
def __init__(
self, env: PipelineEnv, seed: int = 0, backend: Optional[str] = None
):
self._env = env
self.metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 1 / self._env.dt
'video.frames_per_second': 1 / self._env.dt,
}
if not hasattr(self._env, 'batch_size'):
raise ValueError('underlying env must be batched')
Expand Down
12 changes: 8 additions & 4 deletions brax/envs/wrappers/gym_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def test_action_space(self):
base_env = envs.create('pusher')
env = gym.GymWrapper(base_env)
np.testing.assert_array_equal(
env.action_space.low, base_env.sys.actuator.ctrl_range[:, 0])
env.action_space.low, base_env.sys.actuator.ctrl_range[:, 0]
)
np.testing.assert_array_equal(
env.action_space.high, base_env.sys.actuator.ctrl_range[:, 1])
env.action_space.high, base_env.sys.actuator.ctrl_range[:, 1]
)


def test_vector_action_space(self):
Expand All @@ -39,10 +41,12 @@ def test_vector_action_space(self):
env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=256))
np.testing.assert_array_equal(
env.action_space.low,
np.tile(base_env.sys.actuator.ctrl_range[:, 0], [256, 1]))
np.tile(base_env.sys.actuator.ctrl_range[:, 0], [256, 1]),
)
np.testing.assert_array_equal(
env.action_space.high,
np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1]))
np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1]),
)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions brax/envs/wrappers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
This conversion happens directly on-device, without moving values to the CPU.
"""

from typing import Optional

# NOTE: The following line will emit a warning and raise ImportError if `torch`
Expand Down
1 change: 1 addition & 0 deletions brax/envs/wrappers/training_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Tests for training wrappers."""

import functools

from absl.testing import absltest
Expand Down
1 change: 1 addition & 0 deletions brax/generalized/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def point_jacobian(
Returns:
pt: point jacobian
"""

# backward scan up tree: build the link mask corresponding to link_idx
def mask_fn(mask_child, link):
mask = link == link_idx
Expand Down
12 changes: 6 additions & 6 deletions brax/generalized/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ def transform_com(sys: System, state: State) -> State:
cinr = x_i.replace(pos=x_i.pos - root_com).vmap().do(sys.link.inertia)

# motion dofs to global frame centered at subtree-CoM
parent_idx = jp.array(
[
i if t == 'f' else p
for i, (t, p) in enumerate(zip(sys.link_types, sys.link_parents))
]
)
parent_idx = jp.array([
i if t == 'f' else p
for i, (t, p) in enumerate(zip(sys.link_types, sys.link_parents))
])
parent = state.x.concatenate(Transform.zero(shape=(1,))).take(parent_idx)
j = parent.vmap().do(sys.link.transform).vmap().do(sys.link.joint)

Expand Down Expand Up @@ -150,6 +148,7 @@ def inverse(sys: System, state: State) -> jax.Array:
Returns:
tau: generalized forces resulting from joint positions and velocities
"""

# forward scan over tree: accumulate link center of mass acceleration
def cdd_fn(cdd_parent, cdofd, qd, dof_idx):
if cdd_parent is None:
Expand Down Expand Up @@ -187,6 +186,7 @@ def cfrc_fn(cfrc_child, cfrc):

def _passive(sys: System, state: State) -> jax.Array:
"""Calculates the system's passive forces given input motion and position."""

def stiffness_fn(typ, q, dof):
if typ in 'fb':
return jp.zeros_like(dof.stiffness)
Expand Down
Loading

0 comments on commit 18d6a0a

Please sign in to comment.