diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 5af2b9556651..c64d3763d59f 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1057,6 +1057,8 @@ class ImageIter(io.DataIter): Data name for provided symbols. label_name : str Label name for provided symbols. + dtype : str + Label data type. Default: float32. Other options: int32, int64, float64 kwargs : ... More arguments for creating augmenter. See mx.image.CreateAugmenter. """ @@ -1064,7 +1066,7 @@ class ImageIter(io.DataIter): def __init__(self, batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, - data_name='data', label_name='softmax_label', **kwargs): + data_name='data', label_name='softmax_label', dtype='float32', **kwargs): super(ImageIter, self).__init__() assert path_imgrec or path_imglist or (isinstance(imglist, list)) num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1) @@ -1091,7 +1093,7 @@ def __init__(self, batch_size, data_shape, label_width=1, imgkeys = [] for line in iter(fin.readline, ''): line = line.strip().split('\t') - label = nd.array([float(i) for i in line[1:-1]]) + label = nd.array([i for i in line[1:-1]], dtype=dtype) key = int(line[0]) imglist[key] = (label, line[-1]) imgkeys.append(key) @@ -1105,11 +1107,11 @@ def __init__(self, batch_size, data_shape, label_width=1, key = str(index) # pylint: disable=redefined-variable-type index += 1 if len(img) > 2: - label = nd.array(img[:-1]) + label = nd.array(img[:-1], dtype=dtype) elif isinstance(img[0], numeric_types): - label = nd.array([img[0]]) + label = nd.array([img[0]], dtype=dtype) else: - label = nd.array(img[0]) + label = nd.array(img[0], dtype=dtype) result[key] = (label, img[-1]) imgkeys.append(str(key)) self.imglist = result diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 636c5e2be67c..80ebcc6265db 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -132,26 +132,30 @@ def test_color_normalize(self): def test_imageiter(self): - im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES] - test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list, - path_root='') - for _ in range(3): + def check_ImageIter(dtype='float32'): + im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES] + test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list, + path_root='', dtype=dtype) + for _ in range(3): + for batch in test_iter: + pass + test_iter.reset() + + # test with list file + fname = './data/test_imageiter.lst' + file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \ + for k, x in enumerate(TestImage.IMAGES)] + with open(fname, 'w') as f: + for line in file_list: + f.write(line + '\n') + + test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname, + path_root='', dtype=dtype) for batch in test_iter: pass - test_iter.reset() - - # test with list file - fname = './data/test_imageiter.lst' - file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \ - for k, x in enumerate(TestImage.IMAGES)] - with open(fname, 'w') as f: - for line in file_list: - f.write(line + '\n') - test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname, - path_root='') - for batch in test_iter: - pass + for dtype in ['int32', 'float32', 'int64', 'float64']: + check_ImageIter(dtype) @with_seed() def test_augmenters(self):