diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index 333e531a..7dcfc571 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -92,7 +92,7 @@ def restore_central_value_function(self, fn):
         self.set_central_value_function_weights(checkpoint)

     def get_masked_action_values(self, obs, action_masks):
-        assert False
+        raise NotImplementedError("Masked action values are not implemented for continuous actions")

     def calc_gradients(self, input_dict):
         """Compute gradients needed to step the networks of the algorithm.
diff --git a/rl_games/algos_torch/moving_mean_std.py b/rl_games/algos_torch/moving_mean_std.py
index 363da8f4..bd2ab0b3 100644
--- a/rl_games/algos_torch/moving_mean_std.py
+++ b/rl_games/algos_torch/moving_mean_std.py
@@ -3,6 +3,7 @@
 import numpy as np
 import rl_games.algos_torch.torch_ext as torch_ext
+

 '''
 updates moving statistics with momentum
 '''
@@ -76,7 +77,6 @@ def _get_stats(self):
         else:
             raise NotImplementedError(self.impl)

-
     def _update_stats(self, x):
         m = self.decay
         if self.impl == 'off':
@@ -108,7 +108,6 @@ def forward(self, input, mask=None, denorm=False):
             self._update_stats(input)
         offset, invscale = self._get_stats()

-
         if denorm:
             y = input * invscale + offset
         else:
diff --git a/rl_games/algos_torch/sac_agent.py b/rl_games/algos_torch/sac_agent.py
index 26904127..06b94c12 100644
--- a/rl_games/algos_torch/sac_agent.py
+++ b/rl_games/algos_torch/sac_agent.py
@@ -257,7 +257,8 @@ def set_full_state_weights(self, weights, set_epoch=True):
             self.vec_env.set_env_state(env_state)

     def restore(self, fn, set_epoch=True):
-        print("SAC restore")
+        if not os.path.exists(fn):
+            raise FileNotFoundError(f"Checkpoint file not found: {fn}")
         checkpoint = torch_ext.load_checkpoint(fn)
         self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

@@ -268,7 +269,7 @@ def set_param(self, param_name, param_value):
         pass

     def get_masked_action_values(self, obs, action_masks):
-        assert False
+        raise NotImplementedError("Masked action values are not supported in SAC agent")

     def set_eval(self):
         self.model.eval()
@@ -425,7 +426,8 @@ def act(self, obs, action_dim, sample=False):

         actions = dist.sample() if sample else dist.mean
         actions = actions.clamp(*self.action_range)
-        assert actions.ndim == 2
+        if actions.ndim != 2:
+            raise ValueError(f"Actions tensor must be 2-dimensional, got shape {actions.shape}")

         return actions
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 2a2d78a6..24736e87 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -484,7 +484,8 @@ def init_tensors(self):

             total_agents = self.num_agents * self.num_actors
             num_seqs = self.horizon_length // self.seq_length
-            assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
+            if not ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0):
+                raise ValueError(f"Horizon length ({self.horizon_length}) times total agents ({total_agents}) divided by num minibatches ({self.num_minibatches}) must be divisible by sequence length ({self.seq_length})")
             self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

     def init_current_rewards(self, batch_size, current_rewards_shape):
diff --git a/rl_games/common/diagnostics.py b/rl_games/common/diagnostics.py
index 18297460..504643c7 100644
--- a/rl_games/common/diagnostics.py
+++ b/rl_games/common/diagnostics.py
@@ -1,6 +1,7 @@
 import torch
 import rl_games.algos_torch.torch_ext as torch_ext
+

 class DefaultDiagnostics(object):
     def __init__(self):
         pass
@@ -24,7 +25,7 @@ def __init__(self):
     def send_info(self, writter):
         if writter is None:
             return
-        for k,v in self.diag_dict.items():
+        for k, v in self.diag_dict.items():
             writter.add_scalar(k, v.cpu().numpy(), self.current_epoch)

     def epoch(self, agent, current_epoch):
@@ -58,7 +59,3 @@ def mini_batch(self, agent, batch, e_clip, minibatch):
         clip_frac = torch_ext.policy_clip_fraction(new_neglogp, old_neglogp, e_clip, masks)
         self.exp_vars.append(exp_var)
         self.clip_fracs.append(clip_frac)
-
-
-
-
diff --git a/rl_games/common/experience.py b/rl_games/common/experience.py
index 8be1da35..5310bef0 100644
--- a/rl_games/common/experience.py
+++ b/rl_games/common/experience.py
@@ -378,7 +378,7 @@ def update_data(self, name, index, val):

     def update_data_rnn(self, name, indices, play_mask, val):
         if type(val) is dict:
-            for k,v in val:
+            for k, v in val.items():
                 self.tensor_dict[name][k][indices, play_mask] = v
         else:
             self.tensor_dict[name][indices, play_mask] = val
diff --git a/rl_games/envs/test/test_asymmetric_env.py b/rl_games/envs/test/test_asymmetric_env.py
index 6fec93ed..e231447e 100644
--- a/rl_games/envs/test/test_asymmetric_env.py
+++ b/rl_games/envs/test/test_asymmetric_env.py
@@ -2,16 +2,18 @@
 import numpy as np
 from rl_games.common.wrappers import MaskVelocityWrapper
+

 class TestAsymmetricCritic(gym.Env):
     def __init__(self, wrapped_env_name, **kwargs):
         gym.Env.__init__(self)
         self.apply_mask = kwargs.pop('apply_mask', True)
         self.use_central_value = kwargs.pop('use_central_value', True)
         self.env = gym.make(wrapped_env_name)
-        
+
         if self.apply_mask:
-            if wrapped_env_name not in ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]:
-                raise 'unsupported env'
+            supported_envs = ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]
+            if wrapped_env_name not in supported_envs:
+                raise ValueError(f"Environment {wrapped_env_name} is not supported. Supported environments: {supported_envs}")
             self.mask = MaskVelocityWrapper(self.env, wrapped_env_name).mask
         else:
             self.mask = 1
@@ -47,6 +49,6 @@ def step(self, actions):
         else:
             obses = obs_dict["obs"].astype(np.float32)
         return obses, rewards, dones, info
-        
+
     def has_action_mask(self):
         return False
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index 0f7a9ac8..7a56ca1b 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -18,7 +18,8 @@ def _restore(agent, args):
     if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
         if args['train'] and args.get('load_critic_only', False):
-            assert agent.has_central_value, 'This should only work for asymmetric actor critic'
+            if not hasattr(agent, 'has_central_value') or not agent.has_central_value:
+                raise AttributeError('load_critic_only requires an asymmetric actor critic agent with a central value function')
             agent.restore_central_value_function(args['checkpoint'])
             return
         agent.restore(args['checkpoint'])
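
Why the patch prefers explicit raises over bare asserts: Python strips assert statements when run with the -O flag, so `assert False` silently becomes a no-op and the method falls through returning None, while a raised NotImplementedError or ValueError always surfaces and carries a message. A standalone sketch of that difference (illustrative Python only, not rl_games code):

    # Standalone illustration (not rl_games code) of assert vs. raise.
    # Run once as `python demo.py` and once as `python -O demo.py`:
    # under -O the assert is stripped and the first call silently returns None.

    def masked_values_with_assert(obs, action_masks):
        assert False

    def masked_values_with_raise(obs, action_masks):
        raise NotImplementedError("masked action values are not supported")

    if __name__ == '__main__':
        try:
            result = masked_values_with_assert(None, None)
            print(f"assert variant fell through and returned {result}")
        except AssertionError:
            print("assert variant raised AssertionError")
        try:
            masked_values_with_raise(None, None)
        except NotImplementedError as e:
            print(f"raise variant raised: {e}")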
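
The new ValueError in init_tensors (rl_games/common/a2c_common.py) encodes the requirement that each minibatch of the flattened RNN rollout, horizon_length * total_agents / num_minibatches transitions, splits evenly into sequences of seq_length. A quick pre-flight check using the same arithmetic (hypothetical helper with illustrative names, not part of the rl_games API):

    # Hypothetical config sanity check mirroring the divisibility rule added in init_tensors.
    def check_rnn_batching(horizon_length, num_actors, num_agents, num_minibatches, seq_length):
        total_agents = num_agents * num_actors
        minibatch_size = horizon_length * total_agents // num_minibatches
        if minibatch_size % seq_length != 0:
            raise ValueError(
                f"minibatch of {minibatch_size} transitions is not divisible by seq_length={seq_length}")

    # Example: horizon 64, 8 actors, 1 agent per actor, 4 minibatches
    # -> 64 * 8 / 4 = 128 transitions per minibatch, divisible by seq_length=16.
    check_rnn_batching(horizon_length=64, num_actors=8, num_agents=1,
                       num_minibatches=4, seq_length=16)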