Skip to content

Commit

Permalink
Revision models.detection.yolo (#851)
Browse files (browse the repository at this point in the history)
Co-authored-by: heimish-kyma <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Hongyeob.Kim <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
11 people authored May 20, 2023
1 parent ab9e2e7 commit 8cb0a2d
Show file tree
Hide file tree
Showing 10 changed files with 302 additions and 47 deletions.
25 changes: 9 additions & 16 deletions src/pl_bolts/models/detection/yolo/yolo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from pl_bolts.models.detection.yolo import yolo_layers
from pl_bolts.utils.stability import under_review


@under_review()
class YOLOConfiguration:
"""This class can be used to parse the configuration files of the Darknet YOLOv4 implementation.
Expand Down Expand Up @@ -149,7 +147,6 @@ def convert(key, value):
return sections


@under_review()
def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]:
"""Calls one of the ``_create_<layertype>(config, num_inputs)`` functions to create a PyTorch module from the
layer config.
Expand All @@ -173,8 +170,7 @@ def _create_layer(config: dict, num_inputs: List[int]) -> Tuple[nn.Module, int]:
return create_func[config["type"]](config, num_inputs)


@under_review()
def _create_convolutional(config, num_inputs):
def _create_convolutional(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
module = nn.Sequential()

batch_normalize = config.get("batch_normalize", False)
Expand Down Expand Up @@ -210,15 +206,13 @@ def _create_convolutional(config, num_inputs):
return module, config["filters"]


@under_review()
def _create_maxpool(config, num_inputs):
def _create_maxpool(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
padding = (config["size"] - 1) // 2
module = nn.MaxPool2d(config["size"], config["stride"], padding)
return module, num_inputs[-1]


@under_review()
def _create_route(config, num_inputs):
def _create_route(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
num_chunks = config.get("groups", 1)
chunk_idx = config.get("group_id", 0)

Expand All @@ -234,20 +228,17 @@ def _create_route(config, num_inputs):
return module, num_outputs


@under_review()
def _create_shortcut(config, num_inputs):
def _create_shortcut(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
module = yolo_layers.ShortcutLayer(config["from"])
return module, num_inputs[-1]


@under_review()
def _create_upsample(config, num_inputs):
def _create_upsample(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
module = nn.Upsample(scale_factor=config["stride"], mode="nearest")
return module, num_inputs[-1]


@under_review()
def _create_yolo(config, num_inputs):
def _create_yolo(config: dict, num_inputs: int) -> Tuple[nn.Module, int]:
# The "anchors" list alternates width and height.
anchor_dims = config["anchors"]
anchor_dims = [(anchor_dims[i], anchor_dims[i + 1]) for i in range(0, len(anchor_dims), 2)]
Expand All @@ -264,8 +255,10 @@ def _create_yolo(config, num_inputs):
overlap_loss_func = yolo_layers.SELoss()
elif overlap_loss_name == "giou":
overlap_loss_func = yolo_layers.GIoULoss()
else:
elif overlap_loss_name == "iou":
overlap_loss_func = yolo_layers.IoULoss()
else:
raise ValueError("Unknown overlap loss: " + overlap_loss_name)

module = yolo_layers.DetectionLayer(
num_classes=config["classes"],
Expand Down
25 changes: 9 additions & 16 deletions src/pl_bolts/models/detection/yolo/yolo_layers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor, nn

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils import _TORCH_MESHGRID_REQUIRES_INDEXING, _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -21,7 +20,6 @@
warn_missing_pkg("torchvision")


@under_review()
def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor:
"""Converts box center points and sizes to corner coordinates.
Expand All @@ -38,7 +36,6 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor:
return torch.cat((top_left, bottom_right), -1)


@under_review()
def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor:
"""Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at
the same coordinates.
Expand All @@ -61,7 +58,6 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor:
return inter / union


@under_review()
class SELoss(nn.MSELoss):
def __init__(self):
super().__init__(reduction="none")
Expand All @@ -70,13 +66,11 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
return super().forward(inputs, target).sum(1)


@under_review()
class IoULoss(nn.Module):
def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
return 1.0 - box_iou(inputs, target).diagonal()


@under_review()
class GIoULoss(nn.Module):
def __init__(self) -> None:
super().__init__()
Expand All @@ -89,7 +83,6 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
return 1.0 - generalized_box_iou(inputs, target).diagonal()


@under_review()
class DetectionLayer(nn.Module):
"""A YOLO detection layer.
Expand Down Expand Up @@ -263,7 +256,10 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor:

x_range = torch.arange(width, device=xy.device)
y_range = torch.arange(height, device=xy.device)
grid_y, grid_x = torch.meshgrid(y_range, x_range)
if _TORCH_MESHGRID_REQUIRES_INDEXING:
grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij")
else:
grid_y, grid_x = torch.meshgrid(y_range, x_range)
offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2]
offset = offset.unsqueeze(2) # [height, width, 1, 2]

Expand Down Expand Up @@ -468,15 +464,13 @@ def _calculate_losses(
return losses, hits


@under_review()
class Mish(nn.Module):
"""Mish activation."""

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return x * torch.tanh(nn.functional.softplus(x))


@under_review()
class RouteLayer(nn.Module):
"""Route layer concatenates the output (or part of it) from given layers."""

Expand All @@ -492,12 +486,11 @@ def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) ->
self.num_chunks = num_chunks
self.chunk_idx = chunk_idx

def forward(self, x, outputs):
def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor:
chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers]
return torch.cat(chunks, dim=1)


@under_review()
class ShortcutLayer(nn.Module):
"""Shortcut layer adds a residual connection from the source layer."""

Expand All @@ -510,5 +503,5 @@ def __init__(self, source_layer: int) -> None:
super().__init__()
self.source_layer = source_layer

def forward(self, x, outputs):
def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor:
return outputs[-1] + outputs[self.source_layer]
10 changes: 3 additions & 7 deletions src/pl_bolts/models/detection/yolo/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pl_bolts.models.detection.yolo.yolo_layers import DetectionLayer, RouteLayer, ShortcutLayer
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -23,7 +22,6 @@
log = logging.getLogger(__name__)


@under_review()
class YOLO(LightningModule):
"""PyTorch Lightning implementation of YOLOv3 and YOLOv4.
Expand Down Expand Up @@ -179,7 +177,7 @@ def forward(
)
for layer_idx, layer_hits in enumerate(hits):
hit_rate = torch.true_divide(layer_hits, total_hits) if total_hits > 0 else 1.0
self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False)
self.log(f"layer_{layer_idx}_hit_rate", hit_rate, sync_dist=False, batch_size=images.size(0))

def total_loss(loss_name):
"""Returns the sum of the loss over detection layers."""
Expand Down Expand Up @@ -233,8 +231,8 @@ def validation_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], b
total_loss = torch.stack(tuple(losses.values())).sum()

for name, value in losses.items():
self.log(f"val/{name}_loss", value, sync_dist=True)
self.log("val/total_loss", total_loss, sync_dist=True)
self.log(f"val/{name}_loss", value, sync_dist=True, batch_size=images.size(0))
self.log("val/total_loss", total_loss, sync_dist=True, batch_size=images.size(0))

def test_step(self, batch: Tuple[List[Tensor], List[Dict[str, Tensor]]], batch_idx: int):
"""Evaluates a batch of data from the test set.
Expand Down Expand Up @@ -455,7 +453,6 @@ def _filter_detections(self, detections: Dict[str, Tensor]) -> Dict[str, List[Te
return {"boxes": out_boxes, "scores": out_scores, "classprobs": out_classprobs, "labels": out_labels}


@under_review()
class Resize:
"""Rescales the image and target to given dimensions.
Expand Down Expand Up @@ -486,7 +483,6 @@ def __call__(self, image: Tensor, target: Dict[str, Any]):
return image, target


@under_review()
def run_cli():
from argparse import ArgumentParser

Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification # type: ignore

_NATIVE_AMP_AVAILABLE: bool = module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

_TORCHVISION_AVAILABLE: bool = module_available("torchvision")
_GYM_AVAILABLE: bool = module_available("gym")
_SKLEARN_AVAILABLE: bool = module_available("sklearn")
Expand All @@ -20,6 +19,7 @@
_PL_GREATER_EQUAL_1_4_5 = compare_version("pytorch_lightning", operator.ge, "1.4.5")
_TORCH_ORT_AVAILABLE = module_available("torch_ort")
_TORCH_MAX_VERSION_SPARSEML = compare_version("torch", operator.lt, "1.11.0")
_TORCH_MESHGRID_REQUIRES_INDEXING = compare_version("torch", operator.ge, "1.10.0")
_SPARSEML_AVAILABLE = module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_SPARSEML
_JSONARGPARSE_GREATER_THAN_4_16_0 = compare_version("jsonargparse", operator.gt, "4.16.0")

Expand Down
81 changes: 81 additions & 0 deletions tests/data/yolo_giou.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
[net]
width=256
height=256
channels=3

[convolutional]
batch_normalize=1
filters=8
size=3
stride=1
pad=1
activation=leaky

[route]
layers=-1
groups=2
group_id=1

[maxpool]
size=2
stride=2

[convolutional]
batch_normalize=1
filters=2
size=1
stride=1
pad=1
activation=mish

[convolutional]
batch_normalize=1
filters=4
size=3
stride=1
pad=1
activation=mish

[shortcut]
from=-3
activation=linear

[convolutional]
size=1
stride=1
pad=1
filters=14
activation=linear

[yolo]
mask=2,3
anchors=1,2, 3,4, 5,6, 9,10
classes=2
iou_loss=giou
scale_x_y=1.05
cls_normalizer=1.0
iou_normalizer=0.07
ignore_thresh=0.7

[route]
layers = -4

[upsample]
stride=2

[convolutional]
size=1
stride=1
pad=1
filters=14
activation=linear

[yolo]
mask=0,1
anchors=1,2, 3,4, 5,6, 9,10
classes=2
iou_loss=giou
scale_x_y=1.05
cls_normalizer=1.0
iou_normalizer=0.07
ignore_thresh=0.7
Loading

0 comments on commit 8cb0a2d

Please sign in to comment.