Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update image record iterator tests to check the whole iterator not on…
Browse files Browse the repository at this point in the history
…ly first image
  • Loading branch information
perdasilva authored and Per Goncalves da Silva committed Dec 7, 2018
1 parent fd1e421 commit 39117c5
Showing 1 changed file with 50 additions and 21 deletions.
71 changes: 50 additions & 21 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
import sys
from common import assertRaises
import unittest
try:
from itertools import izip_longest as zip_longest
except:
from itertools import zip_longest


def test_MNISTIter():
Expand Down Expand Up @@ -427,13 +431,46 @@ def check_CSVIter_synthetic(dtype='float32'):
for dtype in ['int32', 'int64', 'float32']:
check_CSVIter_synthetic(dtype=dtype)

# @unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11359")
def test_ImageRecordIter_seed_augmentation():
get_cifar10()
seed_aug = 3

def assert_dataiter_equals(dataiter1, dataiter2):
for batch1, batch2 in zip_longest(dataiter1, dataiter2):

# ensure iterators contain the same number of batches
# zip_longest will return None if on of the iterators have run out of batches
assert batch1 and batch2, 'The iterators do not contain the same number of batches'

# ensure batches are of same length
assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length'

# ensure batch data is the same
for i in range(0, len(batch1.data)):
data1 = batch1.data[i].asnumpy().astype(np.uint8)
data2 = batch2.data[i].asnumpy().astype(np.uint8)
assert(np.array_equal(data1, data2))

def assert_dataiter_not_equals(dataiter1, dataiter2):
for batch1, batch2 in zip_longest(dataiter1, dataiter2):

# ensure iterators are of same length
# zip_longest will return None if on of the iterators have run out of batches
assert batch1 and batch2, 'The iterators do not contain the same number of batches'

# ensure batches are of same length
assert len(batch1.data) == len(batch2.data), 'The returned batches are not of the same length'

# ensure batch data is the same
for i in range(0, len(batch1.data)):
data1 = batch1.data[i].asnumpy().astype(np.uint8)
data2 = batch2.data[i].asnumpy().astype(np.uint8)
if not np.array_equal(data1, data2):
return
assert False, 'Expected data iterators to be different, but they are the same'

# check whether to get constant images after fixing seed_aug
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -449,11 +486,8 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -469,12 +503,12 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_equals(dataiter1, dataiter2)

# check whether to get different images after change seed_aug
dataiter = mx.io.ImageRecordIter(
dataiter1.reset()
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
Expand All @@ -490,32 +524,27 @@ def test_ImageRecordIter_seed_augmentation():
random_h=10,
max_shear_ratio=2,
seed_aug=seed_aug+1)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(not np.array_equal(data,data2))

assert_dataiter_not_equals(dataiter1, dataiter2)

# check whether seed_aug changes the iterator behavior
dataiter = mx.io.ImageRecordIter(
dataiter1 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
test_index = rnd.randint(0, len(batch.data))
data = batch.data[test_index].asnumpy().astype(np.uint8)

dataiter = mx.io.ImageRecordIter(
dataiter2 = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar10_mean.bin",
shuffle=False,
data_shape=(3, 28, 28),
batch_size=3,
seed_aug=seed_aug)
batch = dataiter.next()
data2 = batch.data[test_index].asnumpy().astype(np.uint8)
assert(np.array_equal(data,data2))

assert_dataiter_equals(dataiter1, dataiter2)

if __name__ == "__main__":
test_NDArrayIter()
Expand Down

0 comments on commit 39117c5

Please sign in to comment.