Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add repr for SymbolBlock #14423

Merged
merged 4 commits into from
Mar 15, 2019
Merged

Add repr for SymbolBlock #14423

merged 4 commits into from
Mar 15, 2019

Conversation

vandanavk
Copy link
Contributor

@vandanavk vandanavk commented Mar 13, 2019

Description

SymbolBlock uses Block.__repr__ when print(SymbolBlock) is executed. In Block.__repr__, data is printed only if it is of type Block. Adding a __repr__ to SymbolBlock class that prints the output symbol. This behavior is similar to printing pure Symbol API.

Fixes #13616

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Add a repr for SymbolBlock
    Executing
import mxnet as mx
from mxnet.gluon.model_zoo import vision

alexnet = vision.alexnet(pretrained=True)
alexnet.initialize()
alexnet.hybridize()
alexnet(mx.random.uniform(shape=(1,3,224,224)))
alexnet.export("./alexnet")

block = mx.gluon.nn.SymbolBlock.imports('alexnet-symbol.json', ['data'], 'alexnet-0000.params')
print(block)

gives the output
Before

SymbolBlock(

)

After change

SymbolBlock(
<Symbol alexnet0_dense2_fwd> : 1 -> 1
)

Example with a model with more than one output:

import mxnet as mx
from mxnet import gluon

data = mx.sym.Variable('data')
topk = mx.sym.topk(data, k=3, ret_typ='both')
ctx = mx.cpu()

pre_trained = gluon.nn.SymbolBlock(outputs=topk, inputs=data)
print (pre_trained)

Output:

SymbolBlock(
<Symbol topk2> : 1 -> 2
)

Comments

@eric-haibin-lin @Ishitori

@vandanavk vandanavk requested a review from szha as a code owner March 13, 2019 22:47
@vandanavk
Copy link
Contributor Author

@mxnet-label-bot add [Gluon, pr-awaiting-review]

@marcoabreu marcoabreu added Gluon pr-awaiting-review PR is waiting for code review labels Mar 13, 2019
Copy link
Member

@szha szha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test? Also, what does the current string representation look like, for example for alexnet?

@eric-haibin-lin
Copy link
Member

+1 on szha's comment

@vandanavk
Copy link
Contributor Author

Thanks for inputs @szha @eric-haibin-lin. Added an example in the PR description. Will add a test and submit shortly

@vandanavk
Copy link
Contributor Author

@szha @eric-haibin-lin added a test in test_import() of test_gluon.py

@@ -1024,6 +1024,11 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
ret.collect_params().load(param_file, ctx=ctx)
return ret

def __repr__(self):
s = '{name}(\n{modstr}\n)'
modstr = '\n'.join(['{block}'.format(block=self.__dict__['_cached_graph'][-1])])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use __dict__? or -1 index?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be more useful to show also the input count and output count?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

corrected dict to self._cached_graph.
-1 index as the last element of cached_graph is the output. I could have used cached_graph[1], but did not want to hit index out of bounds (in case it ever happens).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think showing all layers like how HybridSequential prints, would be most useful, but not sure how to do that for Symbol API and SymbolBlock models.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Printing the inner structure might be too complicated and also not readable, thus I suggested printing the count.

Why would cached_graph[1] ever hit index out of bounds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something like this?

SymbolBlock(
<Symbol dense1_fwd>
)
Number of inputs: 1
Number of outputs: 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<Symbol dense1_fwd>: 1->1 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@szha thanks for the inputs. submitted after addressing the comments. Added an example of multiple outputs in the PR description

Copy link
Member

@wkcn wkcn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution! LGTM
I think the test case Is too strict. We can use 'in' to test it.

@vandanavk
Copy link
Contributor Author

@wkcn Do you mean lines[0] check? I did something similar to https://github.com/apache/incubator-mxnet/blob/a6a4fe188d4847a608952fb69ff68e7ab91c5050/tests/python/unittest/test_gluon.py#L237

@szha szha merged commit 226212b into apache:master Mar 15, 2019
vdantu pushed a commit to vdantu/incubator-mxnet that referenced this pull request Mar 31, 2019
* Add repr for SymbolBlock

* Add a test

* Correct self.cached_graph

* Address review comments
nswamy pushed a commit that referenced this pull request Apr 5, 2019
* Add repr for SymbolBlock

* Add a test

* Correct self.cached_graph

* Address review comments
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* Add repr for SymbolBlock

* Add a test

* Correct self.cached_graph

* Address review comments
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Gluon pr-awaiting-review PR is waiting for code review
Projects
None yet
Development

Successfully merging this pull request may close these issues.

SymbolBlock doesn't print anything
5 participants