diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py
index 3242a92a35..b1b9c76edb 100644
--- a/fairseq/optim/fairseq_optimizer.py
+++ b/fairseq/optim/fairseq_optimizer.py
@@ -41,20 +41,24 @@ def optimizer_config(self):
     @property
     def params(self):
         """Return an iterable of the parameters held by the optimizer."""
-        for param_group in self.optimizer.param_groups:
+        for param_group in self.param_groups:
             for p in param_group['params']:
                 yield p
 
+    @property
+    def param_groups(self):
+        return self.optimizer.param_groups
+
     def __getstate__(self):
         return self._optimizer.__getstate__()
 
     def get_lr(self):
         """Return the current learning rate."""
-        return self.optimizer.param_groups[0]['lr']
+        return self.param_groups[0]['lr']
 
     def set_lr(self, lr):
         """Set the learning rate."""
-        for param_group in self.optimizer.param_groups:
+        for param_group in self.param_groups:
             param_group['lr'] = lr
 
     def state_dict(self):
@@ -73,7 +77,7 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):
 
         if optimizer_overrides is not None and len(optimizer_overrides) > 0:
             # override learning rate, momentum, etc. with latest values
-            for group in self.optimizer.param_groups:
+            for group in self.param_groups:
                 group.update(optimizer_overrides)
 
     def backward(self, loss):