-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Add dtype visualization to plot_network #14066
Conversation
python/mxnet/visualization.py
Outdated
shape_dict = dict(zip(interals.list_outputs(), out_shapes)) | ||
shape_dict = dict(zip(internals.list_outputs(), out_shapes)) | ||
draw_type = False | ||
if dtype is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just if dtype:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as shapes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this seems slightly more readable.
draw_type = dtype is not None
if draw_type:
...
@mxnet-label-bot add [pr-awaiting-review, Visualization] |
@@ -55,6 +56,7 @@ def test_plot_network(): | |||
net = mx.sym.SoftmaxOutput(data=net, name='out') | |||
with warnings.catch_warnings(record=True) as w: | |||
digraph = mx.viz.plot_network(net, shape={'data': (100, 200)}, | |||
dtype={'data': np.float32}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we be testing for dtypes other than np.float32
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does not really matter - this test just tries to catch errors during preparation of the picture, and for that the exact type used does not make any difference.
@szha Does anything else need to be done with this PR? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Merged. Thanks for your contribution! |
* Add dtype to plot_network * Added docstring for the new param * Added dtype to the plot_network test * Changes from review * Fixes from review * Fix typo * Retrigger CI
* Add dtype to plot_network * Added docstring for the new param * Added dtype to the plot_network test * Changes from review * Fixes from review * Fix typo * Retrigger CI
* Add dtype to plot_network * Added docstring for the new param * Added dtype to the plot_network test * Changes from review * Fixes from review * Fix typo * Retrigger CI
Description
Add possibility to print type information alongside shape in
mx.vis.plot_network
.Checklist
Essentials
Please feel free to remove inapplicable items for your PR.