diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index efb51096c368..f44ff041e35d 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -884,6 +884,8 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
             optimizer = opt.create(self.optimizer,
                                    rescale_grad=(1.0/batch_size), **(self.kwargs))
         elif isinstance(self.optimizer, opt.Optimizer):
             optimizer = self.optimizer
+            if not optimizer.idx2name:
+                optimizer.idx2name = param_idx2name.copy()
 
         # do training
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index a7d3336e8439..e83751d42974 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -505,14 +505,14 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
                 batch_size *= kvstore.num_workers
             rescale_grad = 1.0/batch_size
 
+        idx2name = {}
+        if update_on_kvstore:
+            idx2name.update(enumerate(self._exec_group.param_names))
+        else:
+            for k in range(len(self._context)):
+                idx2name.update({i*len(self._context)+k: n
+                                 for i, n in enumerate(self._exec_group.param_names)})
         if isinstance(optimizer, str):
-            idx2name = {}
-            if update_on_kvstore:
-                idx2name.update(enumerate(self._exec_group.param_names))
-            else:
-                for k in range(len(self._context)):
-                    idx2name.update({i*len(self._context)+k: n
-                                     for i, n in enumerate(self._exec_group.param_names)})
             optimizer_params = dict(optimizer_params)
             if 'rescale_grad' not in optimizer_params:
                 optimizer_params['rescale_grad'] = rescale_grad
@@ -528,6 +528,8 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
                     "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%(
                         optimizer.rescale_grad, rescale_grad) +
                     "Is this intended?", stacklevel=2)
+            if not optimizer.idx2name:
+                optimizer.idx2name = idx2name.copy()
 
         self._optimizer = optimizer
         self._kvstore = kvstore
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 36c1993bf0ff..c82afdfe033a 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -931,6 +931,34 @@ def test_module_update_no_pragram():
     mod.update()
     assert(mod.get_outputs()[0].shape == data_shape)
 
+
+def test_module_init_optimizer():
+    def get_module_idx2name(mod):
+        idx2name = {}
+        idx2name.update(enumerate(mod._exec_group.param_names))
+        return idx2name
+
+    data = mx.sym.Variable('data')
+    sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')
+    batch_size = 8
+    opt_params = {'learning_rate': 1, 'rescale_grad': 1.0 / batch_size}
+
+    # Pass an optimizer str
+    mod1 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+    mod1.bind(data_shapes=[('data', (batch_size, 20))])
+    mod1.init_params()
+    mod1.init_optimizer(optimizer='sgd', optimizer_params=opt_params)
+    assert mod1._optimizer.idx2name == get_module_idx2name(mod1)
+
+    # Pass an Optimizer object
+    mod2 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+    mod2.bind(data_shapes=[('data', (batch_size, 20))])
+    mod2.init_params()
+    opt = mx.optimizer.SGD(**opt_params)
+    mod2.init_optimizer(optimizer=opt)
+    assert mod2._optimizer.idx2name == get_module_idx2name(mod2)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
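
A minimal sketch (not part of the patch) of why copying idx2name onto a user-supplied
Optimizer matters: name-keyed settings such as Optimizer.set_lr_mult() and set_wd_mult()
are resolved through idx2name, so when an Optimizer object (rather than a string) is
passed to init_optimizer() without it, those multipliers are silently ignored. The 'fc'
symbol, shapes, and multiplier value below are illustrative only, mirroring the new unit
test; they are not taken from the patch itself.

    import mxnet as mx

    data = mx.sym.Variable('data')
    sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')

    mod = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
    mod.bind(data_shapes=[('data', (8, 20))])
    mod.init_params()

    # Name-based multiplier: only honoured if the optimizer can map
    # parameter indices back to parameter names via idx2name.
    opt = mx.optimizer.SGD(learning_rate=0.1, rescale_grad=1.0 / 8)
    opt.set_lr_mult({'fc_bias': 0.0})   # e.g. keep the bias fixed
    mod.init_optimizer(optimizer=opt)

    # With this change, idx2name is populated for Optimizer objects too.
    print(mod._optimizer.idx2name)      # {0: 'fc_weight', 1: 'fc_bias'}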