Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RetinaNet Object detection with Backbones #529

Merged
merged 150 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
150 commits
Select commit Hold shift + click to select a range
a665049
refactor frcnn
oke-aditya Jan 19, 2021
7827b44
start adding retina
oke-aditya Jan 19, 2021
7a4955a
add test
oke-aditya Jan 19, 2021
2feca3b
complete Retinanet
oke-aditya Jan 20, 2021
c4e25bd
Isort test
oke-aditya Jan 20, 2021
f912d9e
reformat
oke-aditya Jan 20, 2021
00322cd
update requirments
oke-aditya Jan 20, 2021
6513b74
bump torch
oke-aditya Jan 20, 2021
61aafbf
update
oke-aditya Jan 20, 2021
5b8c65b
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
oke-aditya Jan 20, 2021
c476492
Merge branch 'master' of github.com:PyTorchLightning/pytorch-lightnin…
oke-aditya Jan 20, 2021
a8fda1a
slight documentation edit
oke-aditya Jan 20, 2021
e4803a2
sppedup test
oke-aditya Jan 20, 2021
2a53b90
Apply suggestions from code review
oke-aditya Jan 23, 2021
38e8a2e
remove max epochs
oke-aditya Jan 23, 2021
5bee954
fmt
oke-aditya Jan 23, 2021
6f5e5b4
Apply suggestions from code review
oke-aditya Jan 26, 2021
2542974
fix for cxonsist
oke-aditya Jan 26, 2021
024ef6e
Merge branch 'master' of github.com:PyTorchLightning/pytorch-lightnin…
oke-aditya Jan 30, 2021
9ea1792
Merge branch 'master' into add_retina
oke-aditya Feb 2, 2021
d24643e
Apply suggestions from code review
oke-aditya Feb 3, 2021
66d068c
changes to self.log
oke-aditya Feb 4, 2021
4fc34a0
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
oke-aditya Feb 4, 2021
62de8e7
Merge commit
oke-aditya Feb 16, 2021
d997cdc
add changes
oke-aditya Feb 16, 2021
6b07349
Merge branch 'master' of github.com:PyTorchLightning/pytorch-lightnin…
oke-aditya Feb 22, 2021
b59d340
refactor frcnn
oke-aditya Jan 19, 2021
4df8eab
start adding retina
oke-aditya Jan 19, 2021
d508792
add test
oke-aditya Jan 19, 2021
2ae94f2
complete Retinanet
oke-aditya Jan 20, 2021
ffdddbb
Isort test
oke-aditya Jan 20, 2021
5c4b695
reformat
oke-aditya Jan 20, 2021
91b9d9b
update requirments
oke-aditya Jan 20, 2021
7e5b800
slight documentation edit
oke-aditya Jan 20, 2021
f558a39
sppedup test
oke-aditya Jan 20, 2021
7229dfb
Apply suggestions from code review
oke-aditya Jan 23, 2021
b627fb8
remove max epochs
oke-aditya Jan 23, 2021
085c295
fmt
oke-aditya Jan 23, 2021
9a3623e
Apply suggestions from code review
oke-aditya Jan 26, 2021
7cb0624
fix for cxonsist
oke-aditya Jan 26, 2021
198fdf5
changes to self.log
oke-aditya Feb 4, 2021
b0f7ce5
Apply suggestions from code review
oke-aditya Feb 3, 2021
e6e776d
add changes
oke-aditya Feb 16, 2021
6a106b2
Merge branch 'master' into add_retina
mergify[bot] May 9, 2021
b859f58
Merge branch 'master' into add_retina
mergify[bot] May 9, 2021
af1ea57
Merge branch 'master' into add_retina
mergify[bot] May 9, 2021
8db6a41
Merge branch 'master' into add_retina
mergify[bot] May 10, 2021
f95d840
Merge branch 'master' into add_retina
mergify[bot] May 10, 2021
c7fa861
Merge branch 'master' into add_retina
mergify[bot] May 11, 2021
ae2b186
Merge branch 'master' into add_retina
mergify[bot] May 11, 2021
2bf742d
Merge branch 'master' into add_retina
mergify[bot] May 11, 2021
8289685
Merge branch 'master' into add_retina
mergify[bot] May 14, 2021
c96e52d
Merge branch 'master' into add_retina
mergify[bot] May 17, 2021
2b0dc75
Merge branch 'master' into add_retina
mergify[bot] Jun 15, 2021
3b05001
Merge branch 'master' into add_retina
mergify[bot] Jun 15, 2021
95befd3
Merge branch 'master' into add_retina
mergify[bot] Jun 15, 2021
c6de38d
Merge branch 'master' into add_retina
mergify[bot] Jun 16, 2021
fbf15a8
Merge branch 'master' into add_retina
mergify[bot] Jun 16, 2021
b8cf695
Merge branch 'master' into add_retina
mergify[bot] Jun 16, 2021
5284931
Merge branch 'master' into add_retina
mergify[bot] Jun 16, 2021
a53449a
v0.3.4 & changelog
Borda Jun 17, 2021
bc52bda
Merge branch 'master' into add_retina
mergify[bot] Jun 17, 2021
5fab585
Merge branch 'master' into add_retina
mergify[bot] Jun 17, 2021
fcc77f2
Merge branch 'master' into add_retina
mergify[bot] Jun 21, 2021
8e663ee
Merge branch 'master' into add_retina
mergify[bot] Jun 21, 2021
45f3a17
Merge branch 'master' into add_retina
mergify[bot] Jun 24, 2021
147cc82
Merge branch 'master' of github.com:PyTorchLightning/pytorch-lightnin…
oke-aditya Aug 26, 2021
e732761
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
oke-aditya Aug 26, 2021
550094a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
786734b
Fix
oke-aditya Aug 26, 2021
a194fb6
whatever precommit says
oke-aditya Aug 26, 2021
bd9c409
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2021
6d7c0ac
bump torch for CI
oke-aditya Aug 26, 2021
7dd0974
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
oke-aditya Aug 26, 2021
b0deee9
Merge branch 'master' into add_retina
mergify[bot] Aug 27, 2021
4326a5f
Merge branch 'master' into add_retina
mergify[bot] Aug 27, 2021
3f2d69c
Merge branch 'master' of github.com:PyTorchLightning/pytorch-lightnin…
oke-aditya Sep 13, 2021
5f46b30
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
oke-aditya Sep 13, 2021
f215454
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
1356803
Merge branch 'master' into add_retina
mergify[bot] Sep 23, 2021
10a4fde
Update pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py
oke-aditya Sep 24, 2021
857a19f
Apply suggestions from code review
oke-aditya Sep 24, 2021
9817ced
Merge branch 'master' into add_retina
mergify[bot] Sep 30, 2021
7488e14
Merge branch 'master' into add_retina
mergify[bot] Oct 14, 2021
3f16542
Merge branch 'master' into add_retina
mergify[bot] Oct 14, 2021
6a8d8ec
Merge branch 'master' into add_retina
mergify[bot] Oct 14, 2021
7b4ac77
Merge branch 'master' into add_retina
mergify[bot] Oct 15, 2021
29bb7e4
Merge branch 'master' into add_retina
mergify[bot] Oct 20, 2021
cf9d8c2
Merge branch 'master' into add_retina
mergify[bot] Nov 8, 2021
91e67bd
Merge branch 'master' into add_retina
mergify[bot] Nov 8, 2021
273e5a6
Merge branch 'master' into add_retina
mergify[bot] Nov 8, 2021
73b687d
Merge branch 'master' into add_retina
mergify[bot] Nov 8, 2021
ad92af2
Merge branch 'master' into add_retina
mergify[bot] Nov 15, 2021
65f5c88
Merge branch 'master' into add_retina
Borda Nov 18, 2021
5787806
Merge branch 'master' into add_retina
mergify[bot] Nov 26, 2021
93f0f8e
Merge branch 'master' into add_retina
mergify[bot] Nov 26, 2021
a6cda52
Merge branch 'master' into add_retina
mergify[bot] Nov 26, 2021
ecf5376
Fix hparams.encoder forgotten rename in CPCv2 (#773)
praecipue Nov 26, 2021
93fdcb2
Merge branch 'master' into add_retina
mergify[bot] Nov 26, 2021
75d649f
Merge branch 'master' into add_retina
mergify[bot] Nov 29, 2021
cd3a403
Empty commit to rerun CI
akihironitta Dec 17, 2021
39b4a5d
Rename to retinanet
akihironitta Dec 18, 2021
b089692
ci: quick fix - uninstall torchtext
akihironitta Dec 18, 2021
321bfc9
Update torch version from 1.6 to 1.8
akihironitta Dec 18, 2021
d045844
Merge branch 'ci/update-gpu' into add_retina
akihironitta Dec 18, 2021
8fde41e
Revert "ci: quick fix - uninstall torchtext"
akihironitta Dec 18, 2021
1f69af4
Merge branch 'master' into add_retina
mergify[bot] Dec 18, 2021
5a8035f
Update CHANGELOG
akihironitta Dec 18, 2021
7a15e79
Include RetinaNet in the docs
akihironitta Dec 18, 2021
b8f4572
Fix RetinaNet docstring
akihironitta Dec 18, 2021
c9d613f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2021
6418168
Fix cli_main
akihironitta Dec 18, 2021
a5a01af
Try replace iou with torchvision's
akihironitta Dec 18, 2021
36d01b7
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
akihironitta Dec 18, 2021
60e21b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2021
9e0c175
Fix RetinaNet docstring
akihironitta Dec 18, 2021
7050f83
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
akihironitta Dec 18, 2021
d738dd9
protect torchvision imports
akihironitta Dec 18, 2021
36d15d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2021
ba5d09f
Revert "add changes"
akihironitta Dec 18, 2021
251429a
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
akihironitta Dec 18, 2021
f5aa892
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2021
4661f68
Fix revert
akihironitta Dec 18, 2021
0ea1125
Update validation_epoch_end
akihironitta Dec 18, 2021
5624829
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
akihironitta Dec 18, 2021
e235c68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2021
3dbf3e8
Revert "Try replace iou with torchvision's"
akihironitta Dec 18, 2021
620daf5
Undo unrelated changes
akihironitta Dec 18, 2021
4739208
Make iou evalution self-contained
akihironitta Dec 18, 2021
f4095fb
Undo unrelated changes again
akihironitta Dec 18, 2021
b8da07b
Undo unrelated changes again
akihironitta Dec 18, 2021
e45d45c
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
akihironitta Dec 18, 2021
c30e68e
Fix docstring for doctest
akihironitta Dec 18, 2021
1e0b268
Set pytorch-lightning>=1.4.8
akihironitta Dec 18, 2021
5fccad6
Merge branch 'ci/pl-1.4.8' into add_retina
akihironitta Dec 18, 2021
0af474e
Revert "Merge branch 'ci/pl-1.4.8' into add_retina"
akihironitta Dec 19, 2021
2579bf4
Merge branch 'master' into add_retina
mergify[bot] Dec 19, 2021
3be86f2
Set pytorch-lightning>=1.4.0
akihironitta Dec 19, 2021
b235c35
Use LightningCLI
akihironitta Dec 19, 2021
2e351df
Add LightningCLI requirement
akihironitta Dec 19, 2021
0a7a877
Use LightningCLI in tests
akihironitta Dec 19, 2021
7176d5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2021
6ba5b86
Follow flake8
akihironitta Dec 19, 2021
255c192
Use gpus
akihironitta Dec 19, 2021
afd0a85
Merge branch 'add_retina' of github.com:oke-aditya/pytorch-lightning-…
akihironitta Dec 19, 2021
8182964
Run only on gpu env
akihironitta Dec 19, 2021
8e15fd0
Use LightningCLI v2
akihironitta Dec 19, 2021
6baa9c7
Add LightningCLI requirement
akihironitta Dec 20, 2021
48326ee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 20, 2021
00aa03d
Merge branch 'master' into add_retina
Borda Dec 20, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pl_bolts/metrics/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import torch


def _evaluate_iou(preds: torch.Tensor, target: torch.Tensor):
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""

if preds["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=preds["boxes"].device)
return iou(target["boxes"], preds["boxes"]).diag().mean()


def iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
"""
Calculates the intersection over union.
Expand Down Expand Up @@ -58,6 +70,12 @@ def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target,
where N is the number of prediction bounding boxes and M is the number of target bounding boxes
"""

# degenerate boxes gives inf / nan results
# so do an early check
assert (preds[:, 2:] >= preds[:, :2]).all()
assert (target[:, 2:] >= target[:, :2]).all()
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved

x_min = torch.max(preds[:, None, 0], target[:, 0])
y_min = torch.max(preds[:, None, 1], target[:, 1])
x_max = torch.min(preds[:, None, 2], target[:, 2])
Expand Down
6 changes: 2 additions & 4 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401
from pl_bolts.models.detection.retinanet import RetinaNet # noqa: F401
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
"components",
"FasterRCNN",
]
__all__ = ["components", "FasterRCNN", "RetinaNet"]
20 changes: 3 additions & 17 deletions pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,18 @@
import pytorch_lightning as pl
import torch

from pl_bolts.metrics.object_detection import _evaluate_iou
from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FastRCNNPredictor
from torchvision.ops import box_iou
else: # pragma: no cover
warn_missing_pkg("torchvision")


def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
return box_iou(target["boxes"], pred["boxes"]).diag().mean()


class FasterRCNN(pl.LightningModule):
"""
PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with
Expand All @@ -47,7 +33,7 @@ class FasterRCNN(pl.LightningModule):
CLI command::

# PascalVOC
python faster_rcnn.py --gpus 1 --pretrained True
python faster_rcnn_module.py --gpus 1 --pretrained True
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
Expand Down Expand Up @@ -117,7 +103,7 @@ def validation_step(self, batch, batch_idx):
images, targets = batch
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
iou = torch.stack([_evaluate_iou(o, t) for t, o in zip(targets, outs)]).mean()
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
return {"val_iou": iou}

def validation_epoch_end(self, outs):
Expand Down
4 changes: 4 additions & 0 deletions pl_bolts/models/detection/retinanet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone
from pl_bolts.models.detection.retinanet.retainanet_module import RetinaNet

__all__ = ["create_retinanet_backbone", "RetinaNet"]
41 changes: 41 additions & 0 deletions pl_bolts/models/detection/retinanet/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Any, Optional

import torch.nn as nn

from pl_bolts.models.detection.components import create_torchvision_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
else: # pragma: no cover
warn_missing_pkg("torchvision")


def create_retinanet_backbone(
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
backbone: str,
fpn: bool = True,
pretrained: Optional[str] = None,
trainable_backbone_layers: int = 3,
**kwargs: Any
) -> nn.Module:
"""
Args:
backbone:
Supported backones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152",
"resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2",
as resnets with fpn backbones.
Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101",
"resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19",
fpn: If True then constructs fpn as well.
pretrained: If None creates imagenet weights backbone.
trainable_backbone_layers: number of trainable resnet layers starting from final block.
"""

if fpn:
# Creates a torchvision resnet model with fpn added.
backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs)
else:
# This does not create fpn backbone, it is supported for all models
backbone, _ = create_torchvision_backbone(backbone, pretrained)
return backbone
148 changes: 148 additions & 0 deletions pl_bolts/models/detection/retinanet/retainanet_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from argparse import ArgumentParser
from typing import Any, Optional

import pytorch_lightning as pl
import torch

from pl_bolts.metrics.object_detection import _evaluate_iou
from pl_bolts.models.detection.retinanet import create_retinanet_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.retinanet import RetinaNet as torchvision_RetinaNet
from torchvision.models.detection.retinanet import retinanet_resnet50_fpn, RetinaNetHead
else: # pragma: no cover
warn_missing_pkg("torchvision")


class RetinaNet(pl.LightningModule):
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
"""
PyTorch Lightning implementation of Retina Net `Focal Loss for
Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

Paper authors: Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár

Model implemented by:
- `Aditya Oke <https://github.com/oke-aditya>`

During training, the model expects both the input tensors, as well as targets (list of dictionary), containing:
- boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format.
- labels (`Int64Tensor[N]`): the class label for each ground truh box

CLI command::

# PascalVOC
python retinanet_module.py --gpus 1 --pretrained True
"""

def __init__(
self,
learning_rate: float = 0.0001,
num_classes: int = 91,
backbone: Optional[str] = None,
fpn: bool = True,
pretrained: bool = False,
pretrained_backbone: bool = True,
trainable_backbone_layers: int = 3,
**kwargs: Any,
):
"""
Args:
learning_rate: the learning rate
num_classes: number of detection classes (including background)
backbone: Pretained backbone CNN architecture.
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model pre-trained on COCO train2017
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
super().__init__()
self.learning_rate = learning_rate
self.num_classes = num_classes
self.backbone = backbone
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
if backbone is None:
akihironitta marked this conversation as resolved.
Show resolved Hide resolved
self.model = retinanet_resnet50_fpn(pretrained=pretrained, **kwargs)

self.model.head = RetinaNetHead(
in_channels=self.model.backbone.out_channels,
num_anchors=self.model.head.classification_head.num_anchors,
num_classes=num_classes,
**kwargs
)

else:
backbone_model = create_retinanet_backbone(
self.backbone, fpn, pretrained_backbone, trainable_backbone_layers, **kwargs
)
self.model = torchvision_RetinaNet(backbone_model, num_classes=num_classes, **kwargs)

def forward(self, x):
self.model.eval()
return self.model(x)

def training_step(self, batch, batch_idx):

images, targets = batch
targets = [{k: v for k, v in t.items()} for t in targets]

# fasterrcnn takes both images and targets for training, returns
loss_dict = self.model(images, targets)
loss = sum(loss for loss in loss_dict.values())
return {"loss": loss, "log": loss_dict}
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved

def validation_step(self, batch, batch_idx):
images, targets = batch
# fasterrcnn takes only images for eval() mode
outs = self.model(images)
iou = torch.stack([_evaluate_iou(o, t) for t, o in zip(targets, outs)]).mean()
return {"val_iou": iou}

def validation_epoch_end(self, outs):
avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
logs = {"val_iou": avg_iou}
return {"avg_val_iou": avg_iou, "log": logs}

def configure_optimizers(self):
return torch.optim.SGD(
self.model.parameters(),
lr=self.learning_rate,
momentum=0.9,
weight_decay=0.005,
)

@staticmethod
def add_model_specific_args(parent_parser):
Borda marked this conversation as resolved.
Show resolved Hide resolved
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--num_classes", type=int, default=91)
parser.add_argument("--backbone", type=str, default=None)
parser.add_argument("--fpn", type=bool, default=True)
parser.add_argument("--pretrained", type=bool, default=False)
parser.add_argument("--pretrained_backbone", type=bool, default=True)
parser.add_argument("--trainable_backbone_layers", type=int, default=3)
return parser


def run_cli():
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
from pl_bolts.datamodules import VOCDetectionDataModule

pl.seed_everything(42)
parser = ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--batch_size", type=int, default=1)
parser = RetinaNet.add_model_specific_args(parser)

args = parser.parse_args()

datamodule = VOCDetectionDataModule.from_argparse_args(args)
args.num_classes = datamodule.num_classes

model = RetinaNet(**vars(args))
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
run_cli()
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch>=1.6
torch>=1.7
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
pytorch-lightning>=1.1.1
2 changes: 1 addition & 1 deletion requirements/models.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchvision>=0.7
torchvision>=0.8.1
scikit-learn>=0.23
Pillow
opencv-python
Expand Down
46 changes: 41 additions & 5 deletions tests/models/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from torch.utils.data import DataLoader

from pl_bolts.datasets import DummyDetectionDataset
from pl_bolts.models.detection import FasterRCNN
from pl_bolts.models.detection import FasterRCNN, RetinaNet


def _collate_fn(batch):
return tuple(zip(*batch))


@torch.no_grad()
def test_fasterrcnn():
model = FasterRCNN()

Expand All @@ -18,19 +19,54 @@ def test_fasterrcnn():


def test_fasterrcnn_train(tmpdir):
model = FasterRCNN()
model = FasterRCNN(pretrained=False)
akihironitta marked this conversation as resolved.
Show resolved Hide resolved

train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer = pl.Trainer(
fast_dev_run=True, logger=False, checkpoint_callback=False, max_epochs=1, default_root_dir=tmpdir
)
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl)


def test_fasterrcnn_bbone_train(tmpdir):
model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=True)
model = FasterRCNN(backbone="resnet18", fpn=True, pretrained_backbone=False)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = pl.Trainer(
fast_dev_run=True, logger=False, checkpoint_callback=False, max_epochs=1, default_root_dir=tmpdir
)
trainer.fit(model, train_dl, valid_dl)


@torch.no_grad()
def test_retinanet():
model = RetinaNet(pretrained=False)

image = torch.rand(1, 3, 400, 400)
model(image)


def test_retinanet_train(tmpdir):
model = RetinaNet(pretrained=False)

train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = pl.Trainer(
fast_dev_run=True, logger=False, checkpoint_callback=False, max_epochs=1, default_root_dir=tmpdir
)
trainer.fit(model, train_dataloader=train_dl, val_dataloaders=valid_dl)


def test_retinanet_bbone_train(tmpdir):
oke-aditya marked this conversation as resolved.
Show resolved Hide resolved
model = RetinaNet(backbone="resnet18", fpn=True, pretrained_backbone=False)
train_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)
valid_dl = DataLoader(DummyDetectionDataset(), collate_fn=_collate_fn)

trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer = pl.Trainer(
fast_dev_run=True, logger=False, checkpoint_callback=False, max_epochs=1, default_root_dir=tmpdir
)
trainer.fit(model, train_dl, valid_dl)