Semantic segmentation model #259

Merged 76 commits on Oct 21, 2020
Commits
aa8c782
unet implementation
annikabrundyn Sep 23, 2020
6e15480
unet
annikabrundyn Sep 23, 2020
1b5ef8e
kitti dataset
annikabrundyn Sep 23, 2020
5468f58
kitti dataset
annikabrundyn Sep 23, 2020
bd87f18
kitti dm
annikabrundyn Sep 23, 2020
18c1c4f
kitti dm
annikabrundyn Sep 23, 2020
ba4c9a6
imports
annikabrundyn Sep 23, 2020
052e02c
kitti
annikabrundyn Sep 23, 2020
141adae
kitti
annikabrundyn Sep 23, 2020
d37b685
kitti
annikabrundyn Sep 23, 2020
bb15b0a
kitti
annikabrundyn Sep 23, 2020
382fb7d
kitti
annikabrundyn Sep 23, 2020
583147f
kitti
annikabrundyn Sep 23, 2020
5796819
kitti
annikabrundyn Sep 23, 2020
dfd5560
kitti
annikabrundyn Sep 23, 2020
9e8a174
kitti
annikabrundyn Sep 23, 2020
58523b0
kitti
annikabrundyn Sep 23, 2020
1c47a5d
kitti
annikabrundyn Sep 23, 2020
7db239b
kitti
annikabrundyn Sep 23, 2020
2c9bdfb
kitti
annikabrundyn Sep 23, 2020
d1fed42
kitti
annikabrundyn Sep 23, 2020
f89cfd7
kitti
annikabrundyn Sep 23, 2020
1196f11
kitti
annikabrundyn Sep 23, 2020
21ceed2
kitti
annikabrundyn Sep 23, 2020
7a76535
kitti
annikabrundyn Sep 23, 2020
a652727
kitti
annikabrundyn Sep 23, 2020
598a5d7
kitti
annikabrundyn Sep 23, 2020
8ac3bbf
kitti
annikabrundyn Sep 23, 2020
8f6fddb
kitti
annikabrundyn Sep 23, 2020
a0fa356
kitti
annikabrundyn Sep 23, 2020
4ba33cd
clean up
annikabrundyn Sep 23, 2020
04d7bfa
clean up
annikabrundyn Sep 23, 2020
4e52683
Merge branch 'unet' into segment
annikabrundyn Sep 23, 2020
03a1c8e
Merge branch 'kitti_dm' into segment
annikabrundyn Sep 23, 2020
8569494
semantic segmentation example
annikabrundyn Sep 23, 2020
e27f76d
example
annikabrundyn Sep 23, 2020
d8e306a
example
annikabrundyn Sep 23, 2020
9a27dd1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
annikabrundyn Oct 2, 2020
96fd2c6
add kitti dm
annikabrundyn Oct 2, 2020
bf0c294
model
annikabrundyn Oct 2, 2020
816c5ef
kitti
annikabrundyn Oct 2, 2020
a454130
add test
annikabrundyn Oct 2, 2020
4e07266
add test
annikabrundyn Oct 2, 2020
a6aca44
clean up
annikabrundyn Oct 2, 2020
3db7940
clean up
annikabrundyn Oct 2, 2020
83c1f41
clean up
annikabrundyn Oct 2, 2020
bd9cccb
clean up
annikabrundyn Oct 2, 2020
e646bc9
unet docs
annikabrundyn Oct 6, 2020
78ea7a7
segment docs
annikabrundyn Oct 6, 2020
1241c68
rename file
annikabrundyn Oct 6, 2020
8d36d43
formatting
annikabrundyn Oct 6, 2020
5c73757
formatting
annikabrundyn Oct 6, 2020
b2d806f
formatting
annikabrundyn Oct 6, 2020
b5fd92b
formatting
annikabrundyn Oct 6, 2020
77e9a27
fix tests
annikabrundyn Oct 6, 2020
b0126e1
formatting
annikabrundyn Oct 6, 2020
2a7c1d6
Merge branch 'master' into segment
annikabrundyn Oct 7, 2020
08e7499
conflicts
annikabrundyn Oct 7, 2020
93f616d
fix conflict
annikabrundyn Oct 7, 2020
c0430d0
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
annikabrundyn Oct 7, 2020
539284c
fix conflicts
annikabrundyn Oct 7, 2020
32efe6e
remove new line
annikabrundyn Oct 7, 2020
0186b49
remove extra line
annikabrundyn Oct 7, 2020
39589f3
fix import
annikabrundyn Oct 7, 2020
3464e8e
fix import
annikabrundyn Oct 7, 2020
ad2140b
fix import
annikabrundyn Oct 7, 2020
9b382d1
docs
annikabrundyn Oct 13, 2020
83004b8
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
annikabrundyn Oct 13, 2020
a52783a
Merge branch 'master' into segment
williamFalcon Oct 14, 2020
eb6715c
Merge branch 'master' into segment
annikabrundyn Oct 15, 2020
35e8f10
conflict fix
annikabrundyn Oct 16, 2020
ea253a2
conflict fix
annikabrundyn Oct 16, 2020
dece2fc
fix conflicts
annikabrundyn Oct 16, 2020
4187220
fix conflicts
annikabrundyn Oct 16, 2020
6912304
fix conflicts
annikabrundyn Oct 16, 2020
846beea
update test
annikabrundyn Oct 16, 2020
29 changes: 29 additions & 0 deletions docs/source/convolutional.rst
@@ -26,3 +26,32 @@ Pixel CNN

.. autoclass:: pl_bolts.models.vision.pixel_cnn.PixelCNN
    :noindex:

-------------

UNet
----

.. autoclass:: pl_bolts.models.vision.unet.UNet
    :noindex:

-------------

Semantic Segmentation
---------------------
Model template to use for semantic segmentation tasks. The model uses a UNet architecture by default. Override any part
of this model to build your own variation.

.. code-block:: python

    from pl_bolts.models.vision import SemSegment
    from pl_bolts.datamodules import KittiDataModule
    import pytorch_lightning as pl

    dm = KittiDataModule('path/to/kitti/dataset/', batch_size=4)
    model = SemSegment(datamodule=dm)
    trainer = pl.Trainer()
    trainer.fit(model)

.. autoclass:: pl_bolts.models.vision.segmentation.SemSegment
    :noindex:
1 change: 1 addition & 0 deletions pl_bolts/models/__init__.py
@@ -8,4 +8,5 @@
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.vision import PixelCNN
from pl_bolts.models.vision import UNet
from pl_bolts.models.vision import SemSegment
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
1 change: 1 addition & 0 deletions pl_bolts/models/vision/__init__.py
@@ -1,2 +1,3 @@
from pl_bolts.models.vision.pixel_cnn import PixelCNN
from pl_bolts.models.vision.unet import UNet
from pl_bolts.models.vision.segmentation import SemSegment
123 changes: 123 additions & 0 deletions pl_bolts/models/vision/segmentation.py
@@ -0,0 +1,123 @@
from argparse import ArgumentParser

import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from pl_bolts.models.vision.unet import UNet


class SemSegment(pl.LightningModule):
    def __init__(
        self,
        datamodule: pl.LightningDataModule = None,
        lr: float = 0.01,
        num_classes: int = 19,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False
    ):
        """
        Basic model for semantic segmentation. Uses UNet architecture by default.

        The default parameters in this model are for the KITTI dataset. Note, if you'd like to use this model as is,
        you will first need to download the KITTI dataset yourself. You can download the dataset `here.
        <http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015>`_

        Implemented by:

            - `Annika Brundyn <https://github.com/annikabrundyn>`_

        Args:
            datamodule: LightningDataModule that supplies the dataloaders (required)
            lr: learning rate (default 0.01)
            num_classes: number of output classes (default 19)
            num_layers: number of layers in each side of U-net (default 5)
            features_start: number of features in first layer (default 64)
            bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
        """
        super().__init__()

        assert datamodule is not None, "SemSegment requires a datamodule"
        self.datamodule = datamodule

        self.num_classes = num_classes
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear
        self.lr = lr

        self.net = UNet(num_classes=num_classes,
                        num_layers=self.num_layers,
                        features_start=self.features_start,
                        bilinear=self.bilinear)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        # cross_entropy expects float inputs and integer (long) class targets
        img = img.float()
        mask = mask.long()
        out = self(img)
        # pixels labelled 250 are excluded from the loss
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        log_dict = {'train_loss': loss_val}
        return {'loss': loss_val, 'log': log_dict, 'progress_bar': log_dict}

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        return {'val_loss': loss_val}

    def validation_epoch_end(self, outputs):
        # average the per-batch validation losses over the epoch
        loss_val = torch.stack([x['val_loss'] for x in outputs]).mean()
        log_dict = {'val_loss': loss_val}
        return {'log': log_dict, 'val_loss': log_dict['val_loss'], 'progress_bar': log_dict}

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --data_dir is read by cli_main below but was never registered on the parser
        parser.add_argument("--data_dir", type=str, help="path to the KITTI dataset")
        parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
        parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
        parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
        # channel counts must be integers, so parse as int rather than float
        parser.add_argument("--features_start", type=int, default=64, help="number of features in first layer")
        parser.add_argument("--bilinear", action='store_true', default=False,
                            help="whether to use bilinear interpolation or transposed convolutions for upsampling")

        return parser


def cli_main():
    from pl_bolts.datamodules import KittiDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()

    # trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # model args
    parser = SemSegment.add_model_specific_args(parser)
    args = parser.parse_args()

    # data
    dm = KittiDataModule(args.data_dir, batch_size=args.batch_size)

    # model: pass only the arguments SemSegment.__init__ accepts
    # (args also carries Trainer flags, which __init__ would reject)
    model = SemSegment(
        datamodule=dm,
        lr=args.lr,
        num_layers=args.num_layers,
        features_start=args.features_start,
        bilinear=args.bilinear,
    )

    # train: from_argparse_args is a classmethod, so no throwaway Trainer instance is needed
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model)


if __name__ == '__main__':
    cli_main()
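For reference, the `CosineAnnealingLR(opt, T_max=10)` scheduler set up in `configure_optimizers` decays the learning rate along a half cosine from its initial value down to a minimum. The sketch below evaluates the closed-form schedule in plain Python, assuming the PyTorch default `eta_min=0` and the model's default `lr=0.01`. `cosine_annealing_lr` is a hypothetical helper for illustration and is not part of this PR.

```python
import math


def cosine_annealing_lr(step, base_lr=0.01, t_max=10, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max.

    Hypothetical helper; the defaults mirror lr=0.01 and T_max=10 from the PR,
    and eta_min=0.0 is assumed (the PyTorch default).
    """
    cosine = math.cos(math.pi * step / t_max)
    return eta_min + (base_lr - eta_min) * (1 + cosine) / 2


# learning rate at scheduler steps 0..10
schedule = [cosine_annealing_lr(t) for t in range(11)]

for t, lr in enumerate(schedule):
    print(f"step {t:2d}: lr = {lr:.5f}")
```

Lightning steps schedulers once per epoch by default, so under these assumptions the rate would reach its minimum after 10 epochs.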
19 changes: 8 additions & 11 deletions pl_bolts/models/vision/unet.py
@@ -5,32 +5,29 @@

 class UNet(nn.Module):
     """
-    PyTorch Lightning implementation of `U-Net: Convolutional Networks for Biomedical Image Segmentation
+    Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
     <https://arxiv.org/abs/1505.04597>`_
 
     Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox
 
-    Model implemented by:
+    Implemented by:
 
         - `Annika Brundyn <https://github.com/annikabrundyn>`_
         - `Akshay Kulkarni <https://github.com/akshaykvnit>`_
-    .. warning:: Work in progress. This implementation is still being verified.
+    Args:
+        num_classes: Number of output classes required
+        num_layers: Number of layers in each side of U-net (default 5)
+        features_start: Number of features in first layer (default 64)
+        bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
     """
 
     def __init__(
             self,
             num_classes: int,
             num_layers: int = 5,
             features_start: int = 64,
             bilinear: bool = False
     ):
-        """
-        Args:
-            num_classes: Number of output classes required
-            num_layers: Number of layers in each side of U-net (default 5)
-            features_start: Number of features in first layer (default 64)
-            bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
-        """
         super().__init__()
         self.num_layers = num_layers
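
To make `features_start` and `num_layers` concrete: in the original U-Net paper the number of feature channels doubles at each downsampling step. The sketch below lists the encoder widths under that assumption. The collapsed part of `unet.py` is not shown in this diff, so `unet_channel_widths` is a hypothetical illustration rather than the code in this PR.

```python
def unet_channel_widths(features_start=64, num_layers=5):
    """Channel count at each encoder level, assuming widths double per level.

    Hypothetical helper: the doubling rule follows the U-Net paper and is an
    assumption about the part of unet.py that this diff does not show.
    """
    return [features_start * 2 ** level for level in range(num_layers)]


widths = unet_channel_widths()
print(widths)
```

With the defaults this gives 64 channels at the first level and 1024 at the deepest one, which matches the widths used in the paper.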

22 changes: 21 additions & 1 deletion tests/models/test_vision.py
@@ -2,7 +2,9 @@
 import torch

 from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule
-from pl_bolts.models import GPT2, ImageGPT, UNet
+from pl_bolts.datasets import DummyDataset
+from pl_bolts.models import GPT2, ImageGPT, UNet, SemSegment
+from torch.utils.data import DataLoader


def test_igpt(tmpdir):
@@ -54,3 +56,21 @@ def test_unet(tmpdir):
    model = UNet(num_classes=2)
    y = model(x)
    assert y.shape == torch.Size([10, 2, 28, 28])


def test_semantic_segmentation(tmpdir):

    class DummyDataModule(pl.LightningDataModule):
        def train_dataloader(self):
            train_ds = DummyDataset((3, 35, 120), (35, 120), num_samples=100)
            return DataLoader(train_ds, batch_size=1)

    dm = DummyDataModule()

    model = SemSegment(datamodule=dm, num_classes=19)

    trainer = pl.Trainer(fast_dev_run=True, max_epochs=1)
    trainer.fit(model)
    loss = trainer.progress_bar_dict['loss']

    assert float(loss) > 0
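
The loss checked by this test is the per-pixel cross-entropy from `training_step`, where `ignore_index=250` drops pixels carrying that label before averaging. Below is a minimal pure-Python sketch of that behaviour, assuming mean reduction over the remaining pixels (the `F.cross_entropy` default). `pixel_cross_entropy` is a hypothetical helper written for illustration and is not part of this PR.

```python
import math

IGNORE_INDEX = 250  # same value the PR passes to F.cross_entropy


def pixel_cross_entropy(logits, targets, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over pixels whose target is not ignore_index.

    logits:  list of per-pixel score lists (one score per class)
    targets: list of integer class labels, one per pixel
    """
    total, counted = 0.0, 0
    for scores, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored pixels contribute nothing to the loss
        # log-sum-exp with the max subtracted for numerical stability
        top = max(scores)
        log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
        total += log_norm - scores[target]
        counted += 1
    # no counted pixels means there is nothing to average
    return total / counted if counted else float("nan")


# three pixels, two classes; the third pixel is labelled 250 and skipped
logits = [[2.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
targets = [0, 1, 250]
loss = pixel_cross_entropy(logits, targets)
print(loss)
```

Because the third pixel is ignored, changing its logits leaves the loss unchanged; only the first two pixels enter the mean.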