Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix benchmark on control flow operators. (#12693)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da authored and eric-haibin-lin committed Oct 8, 2018
1 parent 7e46b5e commit 077253d
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions benchmark/python/control_flow/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True)
_parser.add_argument('--warmup_rounds', type=int, default=20)
_parser.add_argument('--test_rounds', type=int, default=100)
_parser.add_argument('--gpu', type=bool, default=False)
args = _parser.parse_args()


Expand Down Expand Up @@ -66,8 +67,7 @@ def _func(*states):
loop_vars=states,
max_iterations=self.length,
)
assert len(out) == 1
return out[0]
return out


def _zeros(shape, ctx):
Expand Down Expand Up @@ -124,7 +124,9 @@ def main():
cell_types = [gluon.rnn.RNNCell,
gluon.rnn.GRUCell,
gluon.rnn.LSTMCell]
ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()]
ctxs = [mx.cpu(0)]
if args.gpu:
ctxs = ctxs + [mx.gpu(i) for i in _get_gpus()]
seq_lens = [100]
batch_sizes = [1, 32]
hidden_dims = [512]
Expand Down

0 comments on commit 077253d

Please sign in to comment.