This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
add import_ for SymbolBlock #11127
Merged
Merged
add import_ for SymbolBlock #11127
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
305eaca
add import_ for SymbolBlock
piiswrong 3f42282
fix
piiswrong d00f6fe
Update block.py
piiswrong 191b173
add save_parameters
piiswrong 844bf2c
fix
piiswrong 2382018
fix lint
piiswrong 93f067d
fix
piiswrong de4ec6e
fix
piiswrong 9a896ce
fix
piiswrong d5b753f
fix
piiswrong 557a131
fix
piiswrong 5bdf327
Update save_load_params.md
piiswrong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
# under the License. | ||
|
||
# coding: utf-8 | ||
# pylint: disable= arguments-differ | ||
# pylint: disable= arguments-differ, too-many-lines | ||
"""Base container class for all neural network models.""" | ||
__all__ = ['Block', 'HybridBlock', 'SymbolBlock'] | ||
|
||
|
@@ -307,7 +307,7 @@ def _collect_params_with_prefix(self, prefix=''): | |
ret.update(child._collect_params_with_prefix(prefix + name)) | ||
return ret | ||
|
||
def save_params(self, filename): | ||
def save_parameters(self, filename): | ||
"""Save parameters to file. | ||
|
||
filename : str | ||
|
@@ -317,8 +317,23 @@ def save_params(self, filename): | |
arg_dict = {key : val._reduce() for key, val in params.items()} | ||
ndarray.save(filename, arg_dict) | ||
|
||
def load_params(self, filename, ctx=None, allow_missing=False, | ||
ignore_extra=False): | ||
def save_params(self, filename): | ||
"""[Deprecated] Please use save_parameters. | ||
|
||
Save parameters to file. | ||
|
||
filename : str | ||
Path to file. | ||
""" | ||
warnings.warn("save_params is deprecated. Please use save_parameters.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we add something about export? Something like "If you are using an hybridized model and want to serialize it to obtain the network structure and parameters, please refer to HybridBlock.export()" |
||
try: | ||
self.collect_params().save(filename, strip_prefix=self.prefix) | ||
except ValueError as e: | ||
raise ValueError('%s\nsave_params is deprecated. Using ' \ | ||
'save_parameters may resolve this error.'%e.message) | ||
|
||
def load_parameters(self, filename, ctx=None, allow_missing=False, | ||
ignore_extra=False): | ||
"""Load parameters from file. | ||
|
||
filename : str | ||
|
@@ -357,6 +372,25 @@ def load_params(self, filename, ctx=None, allow_missing=False, | |
name, filename, _brief_print_list(self._params.keys()))) | ||
params[name]._load_init(loaded[name], ctx) | ||
|
||
def load_params(self, filename, ctx=None, allow_missing=False, | ||
ignore_extra=False): | ||
"""[Deprecated] Please use load_parameters. | ||
|
||
Load parameters from file. | ||
|
||
filename : str | ||
Path to parameter file. | ||
ctx : Context or list of Context, default cpu() | ||
Context(s) initialize loaded parameters on. | ||
allow_missing : bool, default False | ||
Whether to silently skip loading parameters not represents in the file. | ||
ignore_extra : bool, default False | ||
Whether to silently ignore parameters from the file that are not | ||
present in this Block. | ||
""" | ||
warnings.warn("load_params is deprecated. Please use load_parameters.") | ||
self.load_parameters(filename, ctx, allow_missing, ignore_extra) | ||
|
||
def register_child(self, block, name=None): | ||
"""Registers block as a child of self. :py:class:`Block` s assigned to self as | ||
attributes will be registered automatically.""" | ||
|
@@ -770,8 +804,8 @@ def infer_type(self, *args): | |
self._infer_attrs('infer_type', 'dtype', *args) | ||
|
||
def export(self, path, epoch=0): | ||
"""Export HybridBlock to json format that can be loaded by `mxnet.mod.Module` | ||
or the C++ interface. | ||
"""Export HybridBlock to json format that can be loaded by | ||
`SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface. | ||
|
||
.. note:: When there are only one input, it will have name `data`. When there | ||
Are more than one inputs, they will be named as `data0`, `data1`, etc. | ||
|
@@ -885,6 +919,50 @@ class SymbolBlock(HybridBlock): | |
>>> x = mx.nd.random.normal(shape=(16, 3, 224, 224)) | ||
>>> print(feat_model(x)) | ||
""" | ||
@staticmethod | ||
def imports(symbol_file, input_names, param_file=None, ctx=None): | ||
"""Import model previously saved by `HybridBlock.export` or | ||
`Module.save_checkpoint` as a SymbolBlock for use in Gluon. | ||
|
||
Parameters | ||
---------- | ||
symbol_file : str | ||
Path to symbol file. | ||
input_names : list of str | ||
List of input variable names | ||
param_file : str, optional | ||
Path to parameter file. | ||
ctx : Context, default None | ||
The context to initialize SymbolBlock on. | ||
|
||
Returns | ||
------- | ||
SymbolBlock | ||
SymbolBlock loaded from symbol and parameter files. | ||
|
||
Examples | ||
-------- | ||
>>> net1 = gluon.model_zoo.vision.resnet18_v1( | ||
... prefix='resnet', pretrained=True) | ||
>>> net1.hybridize() | ||
>>> x = mx.nd.random.normal(shape=(1, 3, 32, 32)) | ||
>>> out1 = net1(x) | ||
>>> net1.export('net1', epoch=1) | ||
>>> | ||
>>> net2 = gluon.SymbolBlock.imports( | ||
... 'net1-symbol.json', ['data'], 'net1-0001.params') | ||
>>> out2 = net2(x) | ||
""" | ||
sym = symbol.load(symbol_file) | ||
if isinstance(input_names, str): | ||
input_names = [input_names] | ||
inputs = [symbol.var(i) for i in input_names] | ||
ret = SymbolBlock(sym, inputs) | ||
if param_file is not None: | ||
ret.collect_params().load(param_file, ctx=ctx) | ||
return ret | ||
|
||
|
||
def __init__(self, outputs, inputs, params=None): | ||
super(SymbolBlock, self).__init__(prefix=None, params=None) | ||
self._prefix = '' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
L13 there is also "
load_checkpoint
andload
methods" -> "imports
methodThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can make this change as part of another PR to avoid another round of CI.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry didnt realize this was already in.