-
Notifications
You must be signed in to change notification settings - Fork 323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revision models.detection.yolo #851
Changes from 26 commits
cf1646c
fecf88c
b5abc8f
198ebc1
08f17f7
db2601a
fe38bb7
7da9d4a
8c3ed4e
8f69419
b25a864
9ff86ab
17fab64
353f119
a3445ac
189346c
a1d97b6
d5b5fb9
0b4eca4
b52ab5b
eb9930e
fdf38fb
a42cdec
d534cfa
57c9baf
55059f5
bf5b360
3ed65fd
bd23c27
41d2749
0d1c4e7
d758939
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
from typing import Callable, Dict, List, Optional, Tuple | ||
from typing import Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import torch | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
from torch import Tensor, nn | ||
|
||
from pl_bolts.utils import _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.stability import under_review | ||
from pl_bolts.utils import _TORCH_MESHGRID_REQUIRES_INDEXING, _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
|
@@ -21,7 +20,6 @@ | |
warn_missing_pkg("torchvision") | ||
|
||
|
||
@under_review() | ||
def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: | ||
"""Converts box center points and sizes to corner coordinates. | ||
|
||
|
@@ -38,7 +36,6 @@ def _corner_coordinates(xy: Tensor, wh: Tensor) -> Tensor: | |
return torch.cat((top_left, bottom_right), -1) | ||
|
||
|
||
@under_review() | ||
def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: | ||
"""Calculates a matrix of intersections over union from box dimensions, assuming that the boxes are located at | ||
the same coordinates. | ||
|
@@ -61,7 +58,6 @@ def _aligned_iou(dims1: Tensor, dims2: Tensor) -> Tensor: | |
return inter / union | ||
|
||
|
||
@under_review() | ||
class SELoss(nn.MSELoss): | ||
def __init__(self): | ||
super().__init__(reduction="none") | ||
|
@@ -70,13 +66,11 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: | |
return super().forward(inputs, target).sum(1) | ||
|
||
|
||
@under_review() | ||
class IoULoss(nn.Module): | ||
def forward(self, inputs: Tensor, target: Tensor) -> Tensor: | ||
return 1.0 - box_iou(inputs, target).diagonal() | ||
|
||
|
||
@under_review() | ||
class GIoULoss(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
@@ -89,7 +83,6 @@ def forward(self, inputs: Tensor, target: Tensor) -> Tensor: | |
return 1.0 - generalized_box_iou(inputs, target).diagonal() | ||
|
||
|
||
@under_review() | ||
class DetectionLayer(nn.Module): | ||
"""A YOLO detection layer. | ||
|
||
|
@@ -263,7 +256,10 @@ def _global_xy(self, xy: Tensor, image_size: Tensor) -> Tensor: | |
|
||
x_range = torch.arange(width, device=xy.device) | ||
y_range = torch.arange(height, device=xy.device) | ||
grid_y, grid_x = torch.meshgrid(y_range, x_range) | ||
if _TORCH_MESHGRID_REQUIRES_INDEXING: | ||
grid_y, grid_x = torch.meshgrid(y_range, x_range, indexing="ij") | ||
else: | ||
grid_y, grid_x = torch.meshgrid(y_range, x_range) | ||
offset = torch.stack((grid_x, grid_y), -1) # [height, width, 2] | ||
offset = offset.unsqueeze(2) # [height, width, 1, 2] | ||
|
||
|
@@ -468,15 +464,13 @@ def _calculate_losses( | |
return losses, hits | ||
|
||
|
||
@under_review() | ||
class Mish(nn.Module): | ||
"""Mish activation.""" | ||
|
||
def forward(self, x): | ||
def forward(self, x: Tensor) -> Tensor: | ||
return x * torch.tanh(nn.functional.softplus(x)) | ||
|
||
|
||
@under_review() | ||
class RouteLayer(nn.Module): | ||
"""Route layer concatenates the output (or part of it) from given layers.""" | ||
|
||
|
@@ -492,12 +486,11 @@ def __init__(self, source_layers: List[int], num_chunks: int, chunk_idx: int) -> | |
self.num_chunks = num_chunks | ||
self.chunk_idx = chunk_idx | ||
|
||
def forward(self, x, outputs): | ||
def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same with this 'x' here! |
||
chunks = [torch.chunk(outputs[layer], self.num_chunks, dim=1)[self.chunk_idx] for layer in self.source_layers] | ||
return torch.cat(chunks, dim=1) | ||
|
||
|
||
@under_review() | ||
class ShortcutLayer(nn.Module): | ||
"""Shortcut layer adds a residual connection from the source layer.""" | ||
|
||
|
@@ -510,5 +503,5 @@ def __init__(self, source_layer: int) -> None: | |
super().__init__() | ||
self.source_layer = source_layer | ||
|
||
def forward(self, x, outputs): | ||
def forward(self, x, outputs: List[Union[Tensor, None]]) -> Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could review why we are passing 'x' to the forward method and doing nothing with it. Seems to be just to keep with the format... There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @luca-medeiros you're right. RouteLayer and ShortcutLayer do not use |
||
return outputs[-1] + outputs[self.source_layer] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
[net] | ||
width=256 | ||
height=256 | ||
channels=3 | ||
|
||
[convolutional] | ||
batch_normalize=1 | ||
filters=8 | ||
size=3 | ||
stride=1 | ||
pad=1 | ||
activation=leaky | ||
|
||
[route] | ||
layers=-1 | ||
groups=2 | ||
group_id=1 | ||
|
||
[maxpool] | ||
size=2 | ||
stride=2 | ||
|
||
[convolutional] | ||
batch_normalize=1 | ||
filters=2 | ||
size=1 | ||
stride=1 | ||
pad=1 | ||
activation=mish | ||
|
||
[convolutional] | ||
batch_normalize=1 | ||
filters=4 | ||
size=3 | ||
stride=1 | ||
pad=1 | ||
activation=mish | ||
|
||
[shortcut] | ||
from=-3 | ||
activation=linear | ||
|
||
[convolutional] | ||
size=1 | ||
stride=1 | ||
pad=1 | ||
filters=14 | ||
activation=linear | ||
|
||
[yolo] | ||
mask=2,3 | ||
anchors=1,2, 3,4, 5,6, 9,10 | ||
classes=2 | ||
iou_loss=giou | ||
scale_x_y=1.05 | ||
cls_normalizer=1.0 | ||
iou_normalizer=0.07 | ||
ignore_thresh=0.7 | ||
|
||
[route] | ||
layers = -4 | ||
|
||
[upsample] | ||
stride=2 | ||
|
||
[convolutional] | ||
size=1 | ||
stride=1 | ||
pad=1 | ||
filters=14 | ||
activation=linear | ||
|
||
[yolo] | ||
mask=0,1 | ||
anchors=1,2, 3,4, 5,6, 9,10 | ||
classes=2 | ||
iou_loss=giou | ||
scale_x_y=1.05 | ||
cls_normalizer=1.0 | ||
iou_normalizer=0.07 | ||
ignore_thresh=0.7 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should check if that's the desired design. Maybe rather than raising an error, we could set IoU as default and raise a warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to the code six lines above, the default should be `mse`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was based on the original YOLO implementation, which I think uses MSE loss. That's why "mse" was the default. Darknet configuration files allow other variants of the IoU loss (diou, ciou) that were not available in Torchvision at the time of writing this code, so I chose to use IoULoss in case something other than "mse" or "giou" is specified in the configuration file. I created another pull request a while ago that adds support for the ciou and diou losses, which are now available in Torchvision.