Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
momo609 committed Mar 20, 2023
1 parent a384964 commit b6834bd
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tests/test_ops/test_roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE

_USING_PARROTS = True
try:
Expand Down Expand Up @@ -102,7 +102,11 @@ def _test_roialign_allclose(device, dtype):
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [
torch.float,
Expand All @@ -111,6 +115,11 @@ def _test_roialign_allclose(device, dtype):
marks=pytest.mark.skipif(
IS_MLU_AVAILABLE,
reason='MLU does not support for 64-bit floating point')),
pytest.param(
torch.double,
marks=pytest.mark.skipif(
IS_NPU_AVAILABLE,
reason='NPU does not support for 64-bit floating point')),
torch.half
])
def test_roialign(device, dtype):
Expand Down

0 comments on commit b6834bd

Please sign in to comment.