Skip to content

Commit

Permalink
Expose FairseqOptimizer.param_groups property
Browse files Browse the repository at this point in the history
  • Loading branch information
myleott committed Jul 17, 2020
1 parent d15829e commit 8340b2d
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions fairseq/optim/fairseq_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,24 @@ def optimizer_config(self):
@property
def params(self):
    """Return an iterable of the parameters held by the optimizer.

    Yields every tensor from every parameter group, going through the
    ``param_groups`` property (rather than ``self.optimizer.param_groups``
    directly) so subclasses that override the property are respected.
    """
    # The scraped diff kept both the pre- and post-change line of this
    # loop; only the committed version (via the property) is retained.
    for param_group in self.param_groups:
        for p in param_group['params']:
            yield p

@property
def param_groups(self):
    # Expose the wrapped torch optimizer's parameter groups so callers can
    # read/adjust per-group settings (lr, weight decay, ...) without
    # reaching into ``self.optimizer`` directly.
    return self.optimizer.param_groups

def __getstate__(self):
    # Pickle support: delegate state capture to the wrapped optimizer.
    # NOTE(review): this reads the private ``_optimizer`` attribute, while
    # the other methods go through ``self.optimizer`` — presumably
    # ``optimizer`` is a (possibly lazy) property over ``_optimizer``;
    # confirm against the rest of the class.
    return self._optimizer.__getstate__()

def get_lr(self):
    """Return the current learning rate.

    Reads the ``lr`` of the first parameter group via the
    ``param_groups`` property (the committed version of this line; the
    scrape also retained the superseded ``self.optimizer.param_groups``
    form, which is dropped here).
    """
    return self.param_groups[0]['lr']

def set_lr(self, lr):
    """Set the learning rate.

    Applies ``lr`` to every parameter group, iterating through the
    ``param_groups`` property (the committed version; the superseded
    ``self.optimizer.param_groups`` diff line is dropped).
    """
    for param_group in self.param_groups:
        param_group['lr'] = lr

def state_dict(self):
Expand All @@ -73,7 +77,7 @@ def load_state_dict(self, state_dict, optimizer_overrides=None):

if optimizer_overrides is not None and len(optimizer_overrides) > 0:
# override learning rate, momentum, etc. with latest values
for group in self.optimizer.param_groups:
for group in self.param_groups:
group.update(optimizer_overrides)

def backward(self, loss):
Expand Down

0 comments on commit 8340b2d

Please sign in to comment.