Commit 7ea2ffb

Fixed device RNN reset issues.
ViktorM committed Dec 1, 2024
1 parent 04f72d9 commit 7ea2ffb
Showing 7 changed files with 103 additions and 97 deletions.
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rl_games"
version = "1.6.1"
version = "1.6.5"
description = ""
readme = "README.md"
authors = [
@@ -9,7 +9,7 @@ authors = [
]

[tool.poetry.dependencies]
python = ">=3.7.1,<3.11"
python = ">=3.7.1"
gym = {version = "^0.23.0", extras = ["classic_control"]}
tensorboard = "^2.8.0"
tensorboardX = "^2.5"
1 change: 0 additions & 1 deletion rl_games/algos_torch/central_value.py
@@ -77,7 +77,6 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng

if self.is_rnn:
self.rnn_states = self.model.get_default_rnn_state()
-self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
total_agents = self.num_actors #* self.num_agents
num_seqs = self.horizon_length // self.seq_length
assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
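The deleted line above is the heart of the fix: CentralValueTrain no longer copies the default RNN states onto self.ppo_device after creating them, which only makes sense if get_default_rnn_state() now returns tensors already allocated on the model's device. A minimal sketch of that contract (the class name and shapes are hypothetical, and the device lookup via the module's parameters is an assumption, not code from this commit):

    import torch
    import torch.nn as nn

    class A2CNetworkSketch(nn.Module):
        """Hypothetical network illustrating device-aware default RNN states."""

        def __init__(self, input_size, hidden_size, num_layers, num_envs):
            super().__init__()
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers)
            self.num_envs = num_envs

        def get_default_rnn_state(self):
            # Allocate (h0, c0) on the same device as the parameters, so a
            # caller never needs an explicit `.to(ppo_device)` after reset.
            device = next(self.parameters()).device
            shape = (self.rnn.num_layers, self.num_envs, self.rnn.hidden_size)
            return (torch.zeros(shape, device=device),
                    torch.zeros(shape, device=device))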
122 changes: 63 additions & 59 deletions rl_games/algos_torch/models.py
@@ -12,6 +12,7 @@
from torch.distributions import Normal, TransformedDistribution, TanhTransform
import math


class BaseModel():
def __init__(self, model_class):
self.model_class = model_class
@@ -33,6 +34,7 @@ def build(self, config):
return self.Network(self.network_builder.build(self.model_class, **config), obs_shape=obs_shape,
normalize_value=normalize_value, normalize_input=normalize_input, value_size=value_size)


class BaseModelNetwork(nn.Module):
def __init__(self, obs_shape, normalize_value, normalize_input, value_size):
nn.Module.__init__(self)
@@ -56,29 +58,29 @@ def norm_obs(self, observation):
def denorm_value(self, value):
with torch.no_grad():
return self.value_mean_std(value, denorm=True) if self.normalize_value else value

def get_aux_loss(self):
return None


class ModelA2C(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network

class Network(BaseModelNetwork):
def __init__(self, a2c_network, **kwargs):
-BaseModelNetwork.__init__(self,**kwargs)
+BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

def get_value_layer(self):
return self.a2c_network.get_value_layer()
@@ -100,26 +102,27 @@ def forward(self, input_dict):
prev_neglogp = -categorical.log_prob(prev_actions)
entropy = categorical.entropy()
result = {
-'prev_neglogp' : torch.squeeze(prev_neglogp),
-'logits' : categorical.logits,
-'values' : value,
-'entropy' : entropy,
-'rnn_states' : states
+'prev_neglogp': torch.squeeze(prev_neglogp),
+'logits': categorical.logits,
+'values': value,
+'entropy': entropy,
+'rnn_states': states
}
return result
else:
categorical = CategoricalMasked(logits=logits, masks=action_masks)
selected_action = categorical.sample().long()
neglogp = -categorical.log_prob(selected_action)
result = {
-'neglogpacs' : torch.squeeze(neglogp),
-'values' : self.denorm_value(value),
-'actions' : selected_action,
-'logits' : categorical.logits,
-'rnn_states' : states
+'neglogpacs': torch.squeeze(neglogp),
+'values': self.denorm_value(value),
+'actions': selected_action,
+'logits': categorical.logits,
+'rnn_states': states
}
return result


class ModelA2CMultiDiscrete(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
@@ -132,10 +135,10 @@ def __init__(self, a2c_network, **kwargs):

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

@@ -160,16 +163,16 @@ def forward(self, input_dict):
action_masks = np.split(action_masks,len(logits), axis=1)
categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]
prev_actions = torch.split(prev_actions, 1, dim=-1)
-prev_neglogp = [-c.log_prob(a.squeeze()) for c,a in zip(categorical, prev_actions)]
+prev_neglogp = [-c.log_prob(a.squeeze()) for c, a in zip(categorical, prev_actions)]
prev_neglogp = torch.stack(prev_neglogp, dim=-1).sum(dim=-1)
entropy = [c.entropy() for c in categorical]
entropy = torch.stack(entropy, dim=-1).sum(dim=-1)
result = {
-'prev_neglogp' : torch.squeeze(prev_neglogp),
-'logits' : [c.logits for c in categorical],
-'values' : value,
-'entropy' : torch.squeeze(entropy),
-'rnn_states' : states
+'prev_neglogp': torch.squeeze(prev_neglogp),
+'logits': [c.logits for c in categorical],
+'values': value,
+'entropy': torch.squeeze(entropy),
+'rnn_states': states
}
return result
else:
@@ -178,20 +181,21 @@ def forward(self, input_dict):
else:
action_masks = np.split(action_masks, len(logits), axis=1)
categorical = [CategoricalMasked(logits=logit, masks=mask) for logit, mask in zip(logits, action_masks)]

selected_action = [c.sample().long() for c in categorical]
-neglogp = [-c.log_prob(a.squeeze()) for c,a in zip(categorical, selected_action)]
+neglogp = [-c.log_prob(a.squeeze()) for c, a in zip(categorical, selected_action)]
selected_action = torch.stack(selected_action, dim=-1)
neglogp = torch.stack(neglogp, dim=-1).sum(dim=-1)
result = {
-'neglogpacs' : torch.squeeze(neglogp),
-'values' : self.denorm_value(value),
-'actions' : selected_action,
-'logits' : [c.logits for c in categorical],
-'rnn_states' : states
+'neglogpacs': torch.squeeze(neglogp),
+'values': self.denorm_value(value),
+'actions': selected_action,
+'logits': [c.logits for c in categorical],
+'rnn_states': states
}
return result


class ModelA2CContinuous(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
@@ -204,10 +208,10 @@ def __init__(self, a2c_network, **kwargs):

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

def get_default_rnn_state(self):
return self.a2c_network.get_default_rnn_state()

@@ -230,28 +234,27 @@ def forward(self, input_dict):
entropy = distr.entropy().sum(dim=-1)
prev_neglogp = -distr.log_prob(prev_actions).sum(dim=-1)
result = {
-'prev_neglogp' : torch.squeeze(prev_neglogp),
-'value' : value,
-'entropy' : entropy,
-'rnn_states' : states,
-'mus' : mu,
-'sigmas' : sigma
+'prev_neglogp': torch.squeeze(prev_neglogp),
+'value': value,
+'entropy': entropy,
+'rnn_states': states,
+'mus': mu,
+'sigmas': sigma
}
return result
else:
selected_action = distr.sample().squeeze()
neglogp = -distr.log_prob(selected_action).sum(dim=-1)
result = {
-'neglogpacs' : torch.squeeze(neglogp),
-'values' : self.denorm_value(value),
-'actions' : selected_action,
-'entropy' : entropy,
-'rnn_states' : states,
-'mus' : mu,
-'sigmas' : sigma
+'neglogpacs': torch.squeeze(neglogp),
+'values': self.denorm_value(value),
+'actions': selected_action,
+'entropy': entropy,
+'rnn_states': states,
+'mus': mu,
+'sigmas': sigma
}
return result

-return result


class ModelA2CContinuousLogStd(BaseModel):
@@ -266,7 +269,7 @@ def __init__(self, a2c_network, **kwargs):

def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -313,18 +316,19 @@ def neglogp(self, x, mean, std, logstd):
+ 0.5 * np.log(2.0 * np.pi) * x.size()[-1] \
+ logstd.sum(dim=-1)


class ModelA2CContinuousTanh(BaseModel):
def __init__(self, network):
BaseModel.__init__(self, 'a2c')
self.network_builder = network

class Network(BaseModelNetwork):
def __init__(self, a2c_network, **kwargs):
BaseModelNetwork.__init__(self, **kwargs)
self.a2c_network = a2c_network
def get_aux_loss(self):
return self.a2c_network.get_aux_loss()

def is_rnn(self):
return self.a2c_network.is_rnn()

@@ -407,16 +411,15 @@ def forward(self, input_dict):
return result


class ModelSACContinuous(BaseModel):

def __init__(self, network):
BaseModel.__init__(self, 'sac')
self.network_builder = network

class Network(BaseModelNetwork):
-def __init__(self, sac_network,**kwargs):
-BaseModelNetwork.__init__(self,**kwargs)
+def __init__(self, sac_network, **kwargs):
+BaseModelNetwork.__init__(self, **kwargs)
self.sac_network = sac_network

def get_aux_loss(self):
@@ -430,7 +433,7 @@ def critic_target(self, obs, action):

def actor(self, obs):
return self.sac_network.actor(obs)

def is_rnn(self):
return False

@@ -455,6 +458,7 @@ def forward_log_det_jacobian(self, x):
# Log of the absolute value of the determinant of the Jacobian
return 2. * (math.log(2.) - x - F.softplus(-2. * x))


class NormalTanhDistribution:
"""Normal distribution followed by tanh."""

@@ -488,11 +492,11 @@ def sample(self, loc, scale):
def post_process(self, pre_tanh_sample):
"""Returns a postprocessed sample."""
return self._postprocessor.forward(pre_tanh_sample)

def inverse_post_process(self, post_tanh_sample):
"""Returns a postprocessed sample."""
return self._postprocessor.inverse(post_tanh_sample)

def mode(self, loc, scale):
"""Returns the mode of the postprocessed distribution."""
dist = self.create_dist(loc, scale)
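ModelA2CMultiDiscrete.forward above treats each action dimension as an independent categorical and sums the per-dimension negative log-probs into one joint value. A stripped-down sketch of that pattern with plain torch.distributions.Categorical (the batch size and action space are hypothetical, and the CategoricalMasked masking is omitted):

    import torch
    from torch.distributions import Categorical

    # Hypothetical: 2 action dimensions with 3 choices each, batch of 4.
    logits = [torch.randn(4, 3), torch.randn(4, 3)]
    dists = [Categorical(logits=l) for l in logits]

    actions = [d.sample() for d in dists]
    # Joint neglogp is the sum over independent dimensions,
    # mirroring the torch.stack(...).sum(dim=-1) in the diff.
    neglogp = torch.stack(
        [-d.log_prob(a) for d, a in zip(dists, actions)], dim=-1
    ).sum(dim=-1)
    actions = torch.stack(actions, dim=-1)  # shape (4, 2)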
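For reference, forward_log_det_jacobian above returns the numerically stable form of log(1 - tanh(x)^2), the log-determinant of the tanh Jacobian: since 1 - tanh(x)^2 = 4e^(-2x) / (1 + e^(-2x))^2, taking logarithms gives 2 * (log 2 - x - softplus(-2x)). A quick numerical check of the identity (an editor's sketch, not code from this commit):

    import math
    import torch
    import torch.nn.functional as F

    x = torch.linspace(-5.0, 5.0, steps=11)

    # Naive form: log |d tanh(x)/dx| = log(1 - tanh(x)^2).
    naive = torch.log(1.0 - torch.tanh(x) ** 2)

    # Stable form used in forward_log_det_jacobian above.
    stable = 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))

    assert torch.allclose(naive, stable, atol=1e-6)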
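NormalTanhDistribution above samples in pre-tanh space and squashes afterwards, so a squashed sample's log-probability needs that same Jacobian correction. A hedged usage sketch built on plain torch.distributions (the correction term is the standard tanh change of variables; only the method names sample/post_process/mode come from the diff):

    import math
    import torch
    import torch.nn.functional as F
    from torch.distributions import Normal

    loc, scale = torch.zeros(4), torch.ones(4)
    dist = Normal(loc, scale)

    # Sample in pre-tanh space, then post-process through tanh,
    # mirroring sample() followed by post_process().
    pre_tanh = dist.rsample()
    action = torch.tanh(pre_tanh)

    # Base log-prob minus the tanh log-det-Jacobian, per dimension.
    log_prob = (dist.log_prob(pre_tanh)
                - 2.0 * (math.log(2.0) - pre_tanh - F.softplus(-2.0 * pre_tanh)))

    # The mode of the squashed distribution is tanh of the base mean,
    # matching mode() above.
    mode = torch.tanh(loc)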

0 comments on commit 7ea2ffb
