Skip to content

Commit

Permalink
split_and_load can now handle num_ctx > num_data. Issue apache#13909 (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
mightydeveloper authored and haohuw committed Jun 23, 2019
1 parent 456ca1f commit 5f2f326
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
Return value is a list even if `num_slice` is 1.
"""
size = data.shape[batch_axis]
if size < num_slice:
raise ValueError(
"Too many slices for data with shape %s. Arguments are " \
"num_slice=%d and batch_axis=%d."%(str(data.shape), num_slice, batch_axis))
if even_split and size % num_slice != 0:
raise ValueError(
"data with shape %s cannot be evenly split into %d slices along axis %d. " \
Expand All @@ -75,6 +71,12 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
str(data.shape), num_slice, batch_axis, num_slice))

step = size // num_slice

# If size < num_slice, make fewer slices
if not even_split and size < num_slice:
step = 1
num_slice = size

if batch_axis == 0:
slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
for i in range(num_slice)]
Expand Down

0 comments on commit 5f2f326

Please sign in to comment.