Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Support integer type in ImageIter
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Jul 24, 2018
1 parent 38282e9 commit ad60bd3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
12 changes: 7 additions & 5 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,14 +1057,16 @@ class ImageIter(io.DataIter):
Data name for provided symbols.
label_name : str
Label name for provided symbols.
dtype : str
Label data type. Default: float32. Other options: int32, int64, float64
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateAugmenter.
"""

def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='softmax_label', **kwargs):
data_name='data', label_name='softmax_label', dtype='float32', **kwargs):
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist or (isinstance(imglist, list))
num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1)
Expand All @@ -1091,7 +1093,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
imgkeys = []
for line in iter(fin.readline, ''):
line = line.strip().split('\t')
label = nd.array([float(i) for i in line[1:-1]])
label = nd.array([i for i in line[1:-1]], dtype=dtype)
key = int(line[0])
imglist[key] = (label, line[-1])
imgkeys.append(key)
Expand All @@ -1105,11 +1107,11 @@ def __init__(self, batch_size, data_shape, label_width=1,
key = str(index) # pylint: disable=redefined-variable-type
index += 1
if len(img) > 2:
label = nd.array(img[:-1])
label = nd.array(img[:-1], dtype=dtype)
elif isinstance(img[0], numeric_types):
label = nd.array([img[0]])
label = nd.array([img[0]], dtype=dtype)
else:
label = nd.array(img[0])
label = nd.array(img[0], dtype=dtype)
result[key] = (label, img[-1])
imgkeys.append(str(key))
self.imglist = result
Expand Down
38 changes: 21 additions & 17 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,26 +132,30 @@ def test_color_normalize(self):


def test_imageiter(self):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
path_root='')
for _ in range(3):
def check_ImageIter(dtype='float32'):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
path_root='', dtype=dtype)
for _ in range(3):
for batch in test_iter:
pass
test_iter.reset()

# test with list file
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
path_root='', dtype=dtype)
for batch in test_iter:
pass
test_iter.reset()

# test with list file
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
path_root='')
for batch in test_iter:
pass
for dtype in ['int32', 'float32', 'int64', 'float64']:
check_ImageIter(dtype)

@with_seed()
def test_augmenters(self):
Expand Down

0 comments on commit ad60bd3

Please sign in to comment.