Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add TTA transform #2146

Merged
merged 11 commits into from
Oct 19, 2022
6 changes: 3 additions & 3 deletions mmcv/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .loading import LoadAnnotations, LoadImageFromFile
from .processing import (CenterCrop, MultiScaleFlipAug, Normalize, Pad,
RandomChoiceResize, RandomFlip, RandomGrayscale,
RandomResize, Resize)
RandomResize, Resize, TestTimeAug)
from .wrappers import (Compose, KeyMapper, RandomApply, RandomChoice,
TransformBroadcaster)

Expand All @@ -16,7 +16,7 @@
'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations',
'Normalize', 'Resize', 'Pad', 'RandomFlip', 'RandomChoiceResize',
'CenterCrop', 'RandomGrayscale', 'MultiScaleFlipAug', 'RandomResize',
'RandomApply'
'RandomApply', 'TestTimeAug'
]
else:
from .formatting import ImageToTensor, ToTensor, to_tensor
Expand All @@ -26,5 +26,5 @@
'RandomChoice', 'KeyMapper', 'LoadImageFromFile', 'LoadAnnotations',
'Normalize', 'Resize', 'Pad', 'ToTensor', 'to_tensor', 'ImageToTensor',
'RandomFlip', 'RandomChoiceResize', 'CenterCrop', 'RandomGrayscale',
'MultiScaleFlipAug', 'RandomResize', 'RandomApply'
'MultiScaleFlipAug', 'RandomResize', 'RandomApply', 'TestTimeAug'
]
128 changes: 127 additions & 1 deletion mmcv/transforms/processing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import random
import warnings
from itertools import product
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import mmengine
Expand Down Expand Up @@ -746,7 +748,7 @@ class MultiScaleFlipAug(BaseTransform):
- resize to (1333, 800) + flip

The four results are then transformed with ``transforms`` argument.
After that, results are wrapped into lists of the same length as followed:
After that, results are wrapped into lists of the same length as below:

.. code-block::

Expand Down Expand Up @@ -870,6 +872,130 @@ def __repr__(self) -> str:
return repr_str


@TRANSFORMS.register_module()
class TestTimeAug(BaseTransform):
    """Test-time augmentation transform.

    An example configuration is as follows:

    .. code-block::

        dict(type='TestTimeAug',
             transforms=[
                 [dict(type='Resize', scale=(1333, 400), keep_ratio=True),
                  dict(type='Resize', scale=(1333, 800), keep_ratio=True)],
                 [dict(type='RandomFlip', prob=1.),
                  dict(type='RandomFlip', prob=0.)],
                 [dict(type='PackDetInputs',
                       meta_keys=('img_id', 'img_path', 'ori_shape',
                                  'img_shape', 'scale_factor', 'flip',
                                  'flip_direction'))]])

    ``results`` will be transformed using all transforms defined in the
    ``transforms`` argument.

    For the above configuration, there are four combinations of resize
    and flip, generated in this order:

    - Resize to (1333, 400) + flip
    - Resize to (1333, 400) + no flip
    - Resize to (1333, 800) + flip
    - Resize to (1333, 800) + no flip

    After that, results are wrapped into lists of the same length as
    follows:

    .. code-block::

        dict(
            inputs=[...],
            data_samples=[...]
        )

    ``inputs`` and ``data_samples`` both have length 4.

    Required Keys:

    - Depending on the requirements of the ``transforms`` parameter.

    Modified Keys:

    - All output keys of each transform.

    Args:
        transforms (list[list[dict]]): Transforms to be applied to data
            sampled from dataset. ``transforms`` is a list of lists, and
            each inner list usually holds a series of transforms of the
            same type with different arguments. Each element may be a
            config dict (built through ``TRANSFORMS``) or a callable.
            One pipeline is created for every combination that takes
            exactly one element from each inner list. The passed-in
            lists are not modified. See more information in
            :meth:`transform`.

    Raises:
        TypeError: If an element of an inner list is neither a dict nor
            a callable.
        ValueError: If an inner list is empty, since no combination
            (and therefore no output) could be produced.
    """

    def __init__(self, transforms: list):
        # Build into fresh lists instead of assigning back into
        # ``transforms``: the caller's config stays reusable and inner
        # sequences that do not support item assignment (e.g. tuples)
        # are accepted as well.
        built_transforms = []
        for index, transform_list in enumerate(transforms):
            built_list = []
            for transform in transform_list:
                if isinstance(transform, dict):
                    built_list.append(TRANSFORMS.build(transform))
                elif callable(transform):
                    built_list.append(transform)
                else:
                    raise TypeError(
                        'transform must be callable or a dict, but got'
                        f' {type(transform)}')
            if not built_list:
                # ``product`` yields nothing when any factor is empty,
                # which would only surface later as an IndexError in
                # ``transform``. Fail early with a clear message.
                raise ValueError(
                    f'transforms[{index}] is empty, so no augmentation '
                    'combination can be built')
            built_transforms.append(built_list)

        # One ``Compose`` per element of the cartesian product; the
        # first inner list varies slowest.
        self.subroutines = [
            Compose(subroutine) for subroutine in product(*built_transforms)
        ]

    def transform(self, results: dict) -> dict:
        """Apply every augmentation combination to ``results``.

        As in the example given in :obj:`TestTimeAug`, ``transforms``
        consists of 2 ``Resize``, 2 ``RandomFlip`` and 1
        ``PackDetInputs``, which yields 4 subroutines. The data sampled
        from the dataset will be processed as follows:

        1. Each subroutine (one ``Resize``, then one ``RandomFlip``,
           then ``PackDetInputs``) processes its own deep copy of
           ``results`` and must return a dict.
        2. The 4 returned dicts are aggregated key by key, using the
           keys of the first returned dict, so each value of the final
           dict is a list of 4 transformed results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: The augmented data, where each value is wrapped
            into a list.
        """
        results_list = []  # type: ignore
        for subroutine in self.subroutines:
            # Deep copy so that one subroutine cannot affect the input
            # seen by the next one.
            result = subroutine(copy.deepcopy(results))
            # The ``None`` check must come first; otherwise the generic
            # ``isinstance`` assertion would always shadow its message.
            assert result is not None, (
                f'Data processed by {subroutine} in `TestTimeAug` should not '
                'be None! Please check your validation dataset and the '
                f'transforms in {subroutine}')
            assert isinstance(result, dict), (
                f'Data processed by {subroutine} must return a dict, but got '
                f'{result}')
            results_list.append(result)

        aug_data_dict = {
            key: [item[key] for item in results_list]  # type: ignore
            for key in results_list[0]  # type: ignore
        }
        return aug_data_dict

    def __repr__(self) -> str:
        """Return the class name followed by the repr of each subroutine.

        The header format (no separator after the class name) is kept
        unchanged because existing tests compare against it.
        """
        repr_str = self.__class__.__name__
        repr_str += 'transforms=\n'
        for subroutine in self.subroutines:
            repr_str += f'{repr(subroutine)}\n'
        return repr_str


@TRANSFORMS.register_module()
class RandomChoiceResize(BaseTransform):
"""Resize images & bbox & mask from a list of multiple scales.
Expand Down
88 changes: 87 additions & 1 deletion tests/test_transforms/test_transforms_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import mmcv
from mmcv.transforms import (TRANSFORMS, Normalize, Pad, RandomFlip,
RandomResize, Resize)
RandomResize, Resize, TestTimeAug)
from mmcv.transforms.base import BaseTransform

try:
Expand Down Expand Up @@ -900,3 +900,89 @@ def test_transform(self):
resize_type='Resize',
keep_ratio=True)
results_update = TRANSFORMS.transform(copy.deepcopy(results))


class TestTestTimeAug:
    """Unit tests for ``TestTimeAug``."""

    def test_init(self):
        """Config dicts are built and combined into 4 pipelines."""
        resize_cfgs = [
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
        ]
        flip_cfgs = [
            dict(type='RandomFlip', prob=1.),
            dict(type='RandomFlip', prob=0.),
        ]
        normalize_cfgs = [
            dict(type='Normalize', mean=(0, 0, 0), std=(1, 1, 1)),
        ]

        tta_transform = TestTimeAug(
            [resize_cfgs, flip_cfgs, normalize_cfgs])
        pipelines = tta_transform.subroutines
        assert len(pipelines) == 4

        # The first two pipelines must each be Resize -> RandomFlip ->
        # Normalize.
        expected_types = (Resize, RandomFlip, Normalize)
        for pipeline_idx in (0, 1):
            built = pipelines[pipeline_idx].transforms
            for position, expected_type in enumerate(expected_types):
                assert isinstance(built[position], expected_type)

    def test_transform(self):
        """Each output image equals the manually chained transforms."""
        results = {
            'img': np.random.random((224, 224, 3)),
            'gt_bboxes': np.array([[0, 1, 100, 101]]),
            'gt_keypoints': np.array([[[100, 100, 1.0]]]),
            'gt_seg_map': np.random.random((224, 224, 3))
        }
        input_results = copy.deepcopy(results)

        resize_cfgs = [
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
            dict(type='Resize', scale=(1333, 640), keep_ratio=True),
        ]
        flip_cfgs = [
            dict(type='RandomFlip', prob=0.),
            dict(type='RandomFlip', prob=1.),
        ]
        normalize_cfgs = [
            dict(type='Normalize', mean=(0, 0, 0), std=(1, 1, 1)),
        ]

        tta_transform = TestTimeAug(
            [resize_cfgs, flip_cfgs, normalize_cfgs])
        results = tta_transform.transform(results)
        assert len(results['img']) == 4

        # Pull the distinct built transforms back out of the pipelines.
        pipelines = tta_transform.subroutines
        resize1 = pipelines[0].transforms[0]
        resize2 = pipelines[2].transforms[0]
        flip1 = pipelines[0].transforms[1]
        flip2 = pipelines[1].transforms[1]
        normalize = pipelines[0].transforms[2]

        # Same ordering as the pipelines: the resize varies slowest.
        target_results = [
            normalize.transform(
                flip.transform(
                    resize.transform(copy.deepcopy(input_results))))
            for resize in (resize1, resize2)
            for flip in (flip1, flip2)
        ]

        for idx in range(4):
            assert np.allclose(target_results[idx]['img'],
                               results['img'][idx])

    def test_repr(self):
        """The repr starts with the header and the first pipeline."""
        resize_cfgs = [
            dict(type='Resize', scale=(1333, 800), keep_ratio=True),
            dict(type='Resize', scale=(1333, 640), keep_ratio=True),
        ]
        flip_cfgs = [
            dict(type='RandomFlip', prob=0.),
            dict(type='RandomFlip', prob=1.),
        ]
        normalize_cfgs = [
            dict(type='Normalize', mean=(0, 0, 0), std=(1, 1, 1)),
        ]

        tta_transform = TestTimeAug(
            [resize_cfgs, flip_cfgs, normalize_cfgs])
        repr_lines = repr(tta_transform).split('\n')

        assert repr_lines[0] == 'TestTimeAugtransforms='
        assert repr_lines[1] == 'Compose('
        assert repr_lines[2].startswith(' Resize(scale=(1333, 800)')
        assert repr_lines[3].startswith(' RandomFlip(prob=0.0')
        assert repr_lines[4].startswith(' Normalize(mean=[0. 0. 0.]')