Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Nov 29, 2024
1 parent db41372 commit d78f5a3
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
]

[tool.poetry.dependencies]
python = ">=3.7.1,<3.11"
python = ">=3.7.1,<3.14"
gym = {version = "^0.23.0", extras = ["classic_control"]}
tensorboard = "^2.8.0"
tensorboardX = "^2.5"
Expand Down
14 changes: 6 additions & 8 deletions rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rl_games.common import datasets

from torch import optim
import torch
import torch


class A2CAgent(a2c_common.ContinuousA2CBase):
Expand All @@ -30,11 +30,11 @@ def __init__(self, base_name, params):
'actions_num' : self.actions_num,
'input_shape' : obs_shape,
'num_seqs' : self.num_actors * self.num_agents,
'value_size': self.env_info.get('value_size',1),
'value_size': self.env_info.get('value_size', 1),
'normalize_value' : self.normalize_value,
'normalize_input': self.normalize_input,
}

self.model = self.network.build(build_config)
self.model.to(self.ppo_device)
self.states = None
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(self, base_name, params):
def update_epoch(self):
self.epoch_num += 1
return self.epoch_num

def save(self, fn):
state = self.get_full_state_weights()
torch_ext.save_checkpoint(fn, state)
Expand Down Expand Up @@ -114,7 +114,7 @@ def calc_gradients(self, input_dict):

batch_dict = {
'is_train': True,
'prev_actions': actions_batch,
'prev_actions': actions_batch,
'obs' : obs_batch,
}

Expand Down Expand Up @@ -195,7 +195,7 @@ def train_actor_critic(self, input_dict):

def reg_loss(self, mu):
if self.bounds_loss_coef is not None:
reg_loss = (mu*mu).sum(axis=-1)
reg_loss = (mu * mu).sum(axis=-1)
else:
reg_loss = 0
return reg_loss
Expand All @@ -209,5 +209,3 @@ def bound_loss(self, mu):
else:
b_loss = 0
return b_loss


72 changes: 34 additions & 38 deletions rl_games/common/env_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import math



class HCRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)
Expand All @@ -34,8 +33,6 @@ def step(self, action):
return observation, reward, done, info




class DMControlObsWrapper(gym.ObservationWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)
Expand Down Expand Up @@ -96,15 +93,15 @@ def create_myo(**kwargs):
def create_atari_gym_env(**kwargs):
#frames = kwargs.pop('frames', 1)
name = kwargs.pop('name')
skip = kwargs.pop('skip',4)
episode_life = kwargs.pop('episode_life',True)
skip = kwargs.pop('skip', 4)
episode_life = kwargs.pop('episode_life', True)
wrap_impala = kwargs.pop('wrap_impala', False)
env = wrappers.make_atari_deepmind(name, skip=skip,episode_life=episode_life, wrap_impala=wrap_impala, **kwargs)
return env
env = wrappers.make_atari_deepmind(name, skip=skip, episode_life=episode_life, wrap_impala=wrap_impala, **kwargs)
return env

def create_dm_control_env(**kwargs):
frames = kwargs.pop('frames', 1)
name = 'dm2gym:'+ kwargs.pop('name')
name = 'dm2gym:' + kwargs.pop('name')
env = gym.make(name, environment_kwargs=kwargs)
env = DMControlWrapper(env)
env = DMControlObsWrapper(env)
Expand Down Expand Up @@ -140,11 +137,11 @@ def create_super_mario_env_stage1(name='SuperMarioBrosRandomStage1-v1'):

env = gym_super_mario_bros.make(stage_names[1])
env = JoypadSpace(env, SIMPLE_MOVEMENT)

env = wrappers.MaxAndSkipEnv(env, skip=4)
env = wrappers.wrap_deepmind(env, episode_life=False, clip_rewards=False, frame_stack=True, scale=True)
#env = wrappers.AllowBacktracking(env)

return env

def create_quadrupped_env():
Expand All @@ -166,8 +163,7 @@ def create_smac(name, **kwargs):
has_cv = kwargs.get('central_value', False)
as_single_agent = kwargs.pop('as_single_agent', False)
env = SMACEnv(name, **kwargs)



if frames > 1:
if has_cv:
env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten)
Expand All @@ -185,7 +181,7 @@ def create_smac_v2(name, **kwargs):
flatten = kwargs.pop('flatten', True)
has_cv = kwargs.get('central_value', False)
env = SMACEnvV2(name, **kwargs)

if frames > 1:
if has_cv:
env = wrappers.BatchedFrameStackWithStates(env, frames, transpose=False, flatten=flatten)
Expand Down Expand Up @@ -217,15 +213,13 @@ def create_minigrid_env(name, **kwargs):
import gym_minigrid
import gym_minigrid.wrappers


state_bonus = kwargs.pop('state_bonus', False)
action_bonus = kwargs.pop('action_bonus', False)
rgb_fully_obs = kwargs.pop('rgb_fully_obs', False)
rgb_partial_obs = kwargs.pop('rgb_partial_obs', True)
view_size = kwargs.pop('view_size', 3)
env = gym.make(name, **kwargs)


if state_bonus:
env = gym_minigrid.wrappers.StateBonus(env)
if action_bonus:
Expand All @@ -243,7 +237,7 @@ def create_minigrid_env(name, **kwargs):

def create_multiwalker_env(**kwargs):
from rl_games.envs.multiwalker import MultiWalker
env = MultiWalker('', **kwargs)
env = MultiWalker('', **kwargs)

return env

Expand Down Expand Up @@ -290,87 +284,87 @@ def create_env(name, **kwargs):
'vecenv_type' : 'RAY'
},
'PongNoFrameskip-v4' : {
'env_creator' : lambda **kwargs : wrappers.make_atari_deepmind('PongNoFrameskip-v4', skip=4),
'env_creator' : lambda **kwargs : wrappers.make_atari_deepmind('PongNoFrameskip-v4', skip=4),
'vecenv_type' : 'RAY'
},
'BreakoutNoFrameskip-v4' : {
'env_creator' : lambda **kwargs : wrappers.make_atari_deepmind('BreakoutNoFrameskip-v4', skip=4,sticky=False),
'env_creator' : lambda **kwargs : wrappers.make_atari_deepmind('BreakoutNoFrameskip-v4', skip=4, sticky=False),
'vecenv_type' : 'RAY'
},
'MsPacmanNoFrameskip-v4' : {
'env_creator' : lambda **kwargs : wrappers.make_atari_deepmind('MsPacmanNoFrameskip-v4', skip=4),
'env_creator' : lambda **kwargs : wrappers.make_atari_deepmind('MsPacmanNoFrameskip-v4', skip=4),
'vecenv_type' : 'RAY'
},
'CarRacing-v0' : {
'env_creator' : lambda **kwargs : wrappers.make_car_racing('CarRacing-v0', skip=4),
'env_creator' : lambda **kwargs : wrappers.make_car_racing('CarRacing-v0', skip=4),
'vecenv_type' : 'RAY'
},
'RoboschoolAnt-v1' : {
'env_creator' : lambda **kwargs : create_roboschool_env('RoboschoolAnt-v1'),
'vecenv_type' : 'RAY'
},
'SuperMarioBros-v1' : {
'env_creator' : lambda : create_super_mario_env(),
'env_creator' : lambda : create_super_mario_env(),
'vecenv_type' : 'RAY'
},
'SuperMarioBrosRandomStages-v1' : {
'env_creator' : lambda : create_super_mario_env('SuperMarioBrosRandomStages-v1'),
'env_creator' : lambda : create_super_mario_env('SuperMarioBrosRandomStages-v1'),
'vecenv_type' : 'RAY'
},
'SuperMarioBrosRandomStage1-v1' : {
'env_creator' : lambda **kwargs : create_super_mario_env_stage1('SuperMarioBrosRandomStage1-v1'),
'env_creator' : lambda **kwargs : create_super_mario_env_stage1('SuperMarioBrosRandomStage1-v1'),
'vecenv_type' : 'RAY'
},
'RoboschoolHalfCheetah-v1' : {
'env_creator' : lambda **kwargs : create_roboschool_env('RoboschoolHalfCheetah-v1'),
'env_creator' : lambda **kwargs : create_roboschool_env('RoboschoolHalfCheetah-v1'),
'vecenv_type' : 'RAY'
},
'RoboschoolHumanoid-v1' : {
'env_creator' : lambda : wrappers.FrameStack(create_roboschool_env('RoboschoolHumanoid-v1'), 1, True),
'vecenv_type' : 'RAY'
},
'LunarLanderContinuous-v2' : {
'env_creator' : lambda **kwargs : gym.make('LunarLanderContinuous-v2'),
'env_creator' : lambda **kwargs : gym.make('LunarLanderContinuous-v2'),
'vecenv_type' : 'RAY'
},
'RoboschoolHumanoidFlagrun-v1' : {
'env_creator' : lambda **kwargs : wrappers.FrameStack(create_roboschool_env('RoboschoolHumanoidFlagrun-v1'), 1, True),
'env_creator' : lambda **kwargs : wrappers.FrameStack(create_roboschool_env('RoboschoolHumanoidFlagrun-v1'), 1, True),
'vecenv_type' : 'RAY'
},
'BipedalWalker-v3' : {
'env_creator' : lambda **kwargs : create_env('BipedalWalker-v3', **kwargs),
'env_creator' : lambda **kwargs : create_env('BipedalWalker-v3', **kwargs),
'vecenv_type' : 'RAY'
},
'BipedalWalkerCnn-v3' : {
'env_creator' : lambda **kwargs : wrappers.FrameStack(HCRewardEnv(gym.make('BipedalWalker-v3')), 4, False),
'env_creator' : lambda **kwargs : wrappers.FrameStack(HCRewardEnv(gym.make('BipedalWalker-v3')), 4, False),
'vecenv_type' : 'RAY'
},
'BipedalWalkerHardcore-v3' : {
'env_creator' : lambda **kwargs : gym.make('BipedalWalkerHardcore-v3'),
'env_creator' : lambda **kwargs : gym.make('BipedalWalkerHardcore-v3'),
'vecenv_type' : 'RAY'
},
'ReacherPyBulletEnv-v0' : {
'env_creator' : lambda **kwargs : create_roboschool_env('ReacherPyBulletEnv-v0'),
'env_creator' : lambda **kwargs : create_roboschool_env('ReacherPyBulletEnv-v0'),
'vecenv_type' : 'RAY'
},
'BipedalWalkerHardcoreCnn-v3' : {
'env_creator' : lambda : wrappers.FrameStack(gym.make('BipedalWalkerHardcore-v3'), 4, False),
'vecenv_type' : 'RAY'
},
'QuadruppedWalk-v1' : {
'env_creator' : lambda **kwargs : create_quadrupped_env(),
'env_creator' : lambda **kwargs : create_quadrupped_env(),
'vecenv_type' : 'RAY'
},
'FlexAnt' : {
'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/ant.yaml'),
'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/ant.yaml'),
'vecenv_type' : 'ISAAC'
},
'FlexHumanoid' : {
'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/humanoid.yaml'),
'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/humanoid.yaml'),
'vecenv_type' : 'ISAAC'
},
'FlexHumanoidHard' : {
'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/humanoid_hard.yaml'),
'env_creator' : lambda **kwargs : create_flex(FLEX_PATH + '/demo/gym/cfg/humanoid_hard.yaml'),
'vecenv_type' : 'ISAAC'
},
'smac' : {
Expand Down Expand Up @@ -423,7 +417,7 @@ def create_env(name, **kwargs):
},
'brax' : {
'env_creator': lambda **kwargs: create_brax_env(**kwargs),
'vecenv_type': 'BRAX'
'vecenv_type': 'BRAX'
},
'envpool': {
'env_creator': lambda **kwargs: create_envpool(**kwargs),
Expand All @@ -439,6 +433,7 @@ def create_env(name, **kwargs):
},
}


def get_env_info(env):
result_shapes = {}
result_shapes['observation_space'] = env.observation_space
Expand All @@ -450,16 +445,17 @@ def get_env_info(env):
'''
if isinstance(result_shapes['observation_space'], gym.spaces.dict.Dict):
result_shapes['observation_space'] = observation_space['observations']
if isinstance(result_shapes['observation_space'], dict):
result_shapes['observation_space'] = observation_space['observations']
result_shapes['state_space'] = observation_space['states']
'''
if hasattr(env, "value_size"):
if hasattr(env, "value_size"):
result_shapes['value_size'] = env.value_size
print(result_shapes)
return result_shapes


def get_obs_and_action_spaces_from_config(config):
env_config = config.get('env_config', {})
env = configurations[config['env_name']]['env_creator'](**env_config)
Expand All @@ -476,4 +472,4 @@ def register(name, config):
config (:obj:`dict`): Dictionary with env type and a creator function.
"""
configurations[name] = config
configurations[name] = config
44 changes: 20 additions & 24 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""Setup script for rl_games"""

import sys
import os
import pathlib

from setuptools import setup, find_packages
# The directory containing this file
HERE = pathlib.Path(__file__).parent
Expand All @@ -16,34 +13,33 @@
long_description=README,
long_description_content_type="text/markdown",
url="https://github.com/Denys88/rl_games",
#packages=[package for package in find_packages() if package.startswith('rl_games')],
packages = ['.','rl_games','docs'],
package_data={'rl_games':['*','*/*','*/*/*'],'docs':['*','*/*','*/*/*'],},
packages=['.', 'rl_games', 'docs'],
package_data={'rl_games': ['*', '*/*', '*/*/*'], 'docs': ['*', '*/*', '*/*/*'], },
version='1.6.1',
author='Denys Makoviichuk, Viktor Makoviichuk',
author_email='[email protected], [email protected]',
license="MIT",
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10"
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
#packages=["rlg"],
include_package_data=True,
install_requires=[
# this setup is only for pytorch
#
'gym>=0.17.2',
'torch>=1.7.0',
'numpy>=1.16.0',
'tensorboard>=1.14.0',
'tensorboardX>=1.6',
'setproctitle',
'psutil',
'pyyaml',
'watchdog>=2.1.9,<3.0.0', # for evaluation process
'gym>=0.17.2',
'torch>=2.0.0',
'numpy>=1.16.0',
'tensorboard>=1.14.0',
'tensorboardX>=1.6',
'setproctitle',
'psutil',
'pyyaml',
'watchdog>=2.1.9', # for evaluation process
],
)

0 comments on commit d78f5a3

Please sign in to comment.