Commit 249b1f0

Improved error handling.

ViktorM committed Dec 6, 2024
1 parent 65c86c8 commit 249b1f0
Showing 8 changed files with 20 additions and 18 deletions.
2 changes: 1 addition & 1 deletion rl_games/algos_torch/a2c_continuous.py
@@ -92,7 +92,7 @@ def restore_central_value_function(self, fn):
self.set_central_value_function_weights(checkpoint)

def get_masked_action_values(self, obs, action_masks):
assert False
raise NotImplementedError("Masked action values are not implemented for continuous actions")

def calc_gradients(self, input_dict):
"""Compute gradients needed to step the networks of the algorithm.
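The change above replaces a bare assert with a descriptive exception. As a caller-side illustration (the stub class below is hypothetical, not the rl_games agent), the difference matters because assert statements are stripped under python -O, whereas the exception always fires and explains why the call is unsupported:

# Hypothetical stub mirroring the new behaviour; not rl_games code.
class ContinuousA2CStub:
    def get_masked_action_values(self, obs, action_masks):
        raise NotImplementedError(
            "Masked action values are not implemented for continuous actions")

agent = ContinuousA2CStub()
try:
    agent.get_masked_action_values(obs=None, action_masks=None)
except NotImplementedError as err:
    print(f"Feature not available: {err}")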
3 changes: 1 addition & 2 deletions rl_games/algos_torch/moving_mean_std.py
@@ -3,6 +3,7 @@
import numpy as np
import rl_games.algos_torch.torch_ext as torch_ext


'''
updates moving statistics with momentum
'''
@@ -76,7 +77,6 @@ def _get_stats(self):
else:
raise NotImplementedError(self.impl)


def _update_stats(self, x):
m = self.decay
if self.impl == 'off':
@@ -108,7 +108,6 @@ def forward(self, input, mask=None, denorm=False):
self._update_stats(input)

offset, invscale = self._get_stats()

if denorm:
y = input * invscale + offset
else:
8 changes: 5 additions & 3 deletions rl_games/algos_torch/sac_agent.py
@@ -257,7 +257,8 @@ def set_full_state_weights(self, weights, set_epoch=True):
self.vec_env.set_env_state(env_state)

def restore(self, fn, set_epoch=True):
print("SAC restore")
if not os.path.exists(fn):
raise FileNotFoundError(f"Checkpoint file not found: {fn}")
checkpoint = torch_ext.load_checkpoint(fn)
self.set_full_state_weights(checkpoint, set_epoch=set_epoch)

@@ -268,7 +269,7 @@ def set_param(self, param_name, param_value):
pass

def get_masked_action_values(self, obs, action_masks):
assert False
raise NotImplementedError("Masked action values are not supported in SAC agent")

def set_eval(self):
self.model.eval()
@@ -425,7 +426,8 @@ def act(self, obs, action_dim, sample=False):

actions = dist.sample() if sample else dist.mean
actions = actions.clamp(*self.action_range)
assert actions.ndim == 2
if actions.ndim != 2:
raise ValueError(f"Actions tensor must be 2-dimensional, got shape {actions.shape}")

return actions

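With the restore() change above, a missing checkpoint path now fails before torch_ext.load_checkpoint is ever called. A minimal sketch of how a training script might guard against that, assuming a placeholder path and helper (neither is part of the rl_games API):

import os

def restore_checkpoint(fn):
    # Stand-in for agent.restore(fn): fail fast on a missing file.
    if not os.path.exists(fn):
        raise FileNotFoundError(f"Checkpoint file not found: {fn}")
    return fn

try:
    restore_checkpoint("runs/sac_example/nn/last_checkpoint.pth")
except FileNotFoundError as err:
    print(f"Skipping restore: {err}")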
3 changes: 2 additions & 1 deletion rl_games/common/a2c_common.py
@@ -484,7 +484,8 @@ def init_tensors(self):

total_agents = self.num_agents * self.num_actors
num_seqs = self.horizon_length // self.seq_length
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
if not ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0):
raise ValueError(f"Horizon length ({self.horizon_length}) times total agents ({total_agents}) divided by num minibatches ({self.num_minibatches}) must be divisible by sequence length ({self.seq_length})")
self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

def init_current_rewards(self, batch_size, current_rewards_shape):
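The new ValueError in init_tensors() states the RNN layout constraint explicitly: the number of samples per minibatch must be a multiple of the sequence length so that recurrent sequences are never split across minibatches. A small worked check with made-up hyperparameters (the values are illustrative only):

# Illustrative hyperparameters, not taken from any rl_games config.
horizon_length = 32
num_agents, num_actors = 1, 64
num_minibatches = 4
seq_length = 8

total_agents = num_agents * num_actors                                     # 64
samples_per_minibatch = horizon_length * total_agents // num_minibatches   # 2048 // 4 = 512
if samples_per_minibatch % seq_length != 0:
    raise ValueError(
        f"{samples_per_minibatch} samples per minibatch is not divisible "
        f"by seq_length={seq_length}")
print(f"{samples_per_minibatch // seq_length} sequences per minibatch")    # 64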
7 changes: 2 additions & 5 deletions rl_games/common/diagnostics.py
@@ -1,6 +1,7 @@
import torch
import rl_games.algos_torch.torch_ext as torch_ext


class DefaultDiagnostics(object):
def __init__(self):
pass
@@ -24,7 +25,7 @@ def __init__(self):
def send_info(self, writter):
if writter is None:
return
for k,v in self.diag_dict.items():
for k, v in self.diag_dict.items():
writter.add_scalar(k, v.cpu().numpy(), self.current_epoch)

def epoch(self, agent, current_epoch):
@@ -58,7 +59,3 @@ def mini_batch(self, agent, batch, e_clip, minibatch):
clip_frac = torch_ext.policy_clip_fraction(new_neglogp, old_neglogp, e_clip, masks)
self.exp_vars.append(exp_var)
self.clip_fracs.append(clip_frac)




2 changes: 1 addition & 1 deletion rl_games/common/experience.py
@@ -378,7 +378,7 @@ def update_data(self, name, index, val):

def update_data_rnn(self, name, indices, play_mask, val):
if type(val) is dict:
for k,v in val:
for k, v in val:
self.tensor_dict[name][k][indices, play_mask] = v
else:
self.tensor_dict[name][indices, play_mask] = val
10 changes: 6 additions & 4 deletions rl_games/envs/test/test_asymmetric_env.py
@@ -2,16 +2,18 @@
import numpy as np
from rl_games.common.wrappers import MaskVelocityWrapper


class TestAsymmetricCritic(gym.Env):
def __init__(self, wrapped_env_name, **kwargs):
gym.Env.__init__(self)
self.apply_mask = kwargs.pop('apply_mask', True)
self.use_central_value = kwargs.pop('use_central_value', True)
self.env = gym.make(wrapped_env_name)

if self.apply_mask:
if wrapped_env_name not in ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]:
raise 'unsupported env'
supported_envs = ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]
if wrapped_env_name not in supported_envs:
raise ValueError(f"Environment {wrapped_env_name} not supported. Supported environments: {supported_envs}")
self.mask = MaskVelocityWrapper(self.env, wrapped_env_name).mask
else:
self.mask = 1
@@ -47,6 +49,6 @@ def step(self, actions):
else:
obses = obs_dict["obs"].astype(np.float32)
return obses, rewards, dones, info

def has_action_mask(self):
return False
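The validation rewritten above also fixes a latent Python 3 issue: raise 'unsupported env' raises a TypeError (strings cannot be raised), so the intended message was never surfaced. A standalone sketch of the same pattern (the function name is illustrative; the environment list is copied from the diff):

supported_envs = ["CartPole-v1", "Pendulum-v0", "LunarLander-v2", "LunarLanderContinuous-v2"]

def validate_env_name(wrapped_env_name):
    # Reject anything the velocity-masking wrapper does not know about.
    if wrapped_env_name not in supported_envs:
        raise ValueError(
            f"Environment {wrapped_env_name} not supported. "
            f"Supported environments: {supported_envs}")

validate_env_name("CartPole-v1")        # accepted
try:
    validate_env_name("Acrobot-v1")     # rejected with a clear message
except ValueError as err:
    print(err)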
3 changes: 2 additions & 1 deletion rl_games/torch_runner.py
@@ -18,7 +18,8 @@
def _restore(agent, args):
if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
if args['train'] and args.get('load_critic_only', False):
assert agent.has_central_value, 'This should only work for asymmetric actor critic'
if not hasattr(agent, 'has_central_value') or not agent.has_central_value:
raise AttributeError('Loading critic only works only for asymmetric actor critic')
agent.restore_central_value_function(args['checkpoint'])
return
agent.restore(args['checkpoint'])
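The guard above replaces an assert with an explicit check that also tolerates agents lacking a has_central_value attribute altogether. A minimal sketch with stub classes (none of these are rl_games classes):

class AsymmetricAgentStub:
    has_central_value = True

class PlainAgentStub:
    pass  # no central value function, attribute missing entirely

def check_can_load_critic_only(agent):
    # Same pattern as the diff: hasattr() avoids an AttributeError from the lookup itself.
    if not hasattr(agent, 'has_central_value') or not agent.has_central_value:
        raise AttributeError('Loading critic only works only for asymmetric actor critic')

check_can_load_critic_only(AsymmetricAgentStub())    # passes
try:
    check_can_load_critic_only(PlainAgentStub())
except AttributeError as err:
    print(err)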
