
Commit

Update docs
carmocca committed Jan 30, 2023
1 parent d634846 commit 70b8cb6
Showing 7 changed files with 73 additions and 112 deletions.
16 changes: 12 additions & 4 deletions docs/source-pytorch/accelerators/accelerator_prepare.rst
@@ -105,19 +105,27 @@ Note if you use any built in metrics or custom metrics that use `TorchMetrics <h

 It is possible to perform some computation manually and log the reduced result on rank 0 as follows:

-.. testcode::
+.. code-block:: python

+    def __init__(self):
+        super().__init__()
+        self.outputs = []
+
     def test_step(self, batch, batch_idx):
         x, y = batch
         tensors = self(x)
+        self.outputs.append(tensors)
         return tensors

-    def test_epoch_end(self, outputs):
-        mean = torch.mean(self.all_gather(outputs))
+    def on_test_epoch_end(self):
+        mean = torch.mean(self.all_gather(self.outputs))
+        self.outputs.clear()  # free memory

         # When logging only on rank 0, don't forget to add
-        # ``rank_zero_only=True`` to avoid deadlocks on synchronization.
+        # `rank_zero_only=True` to avoid deadlocks on synchronization.
         # caveat: monitoring this is unimplemented. see https://github.com/Lightning-AI/lightning/issues/15852
         if self.trainer.is_global_zero:
             self.log("my_reduced_metric", mean, rank_zero_only=True)
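For reference, the pattern in this hunk drops into a complete module as sketched below. This is illustrative only and not part of the diff: the class name, the `Linear(32, 2)` layer, and the choice to stack the stored outputs before `all_gather` (so `mean` operates on a single tensor) are assumptions.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class RankZeroMetricModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.outputs = []  # filled by test_step, consumed by on_test_epoch_end

        def forward(self, x):
            return self.layer(x)

        def test_step(self, batch, batch_idx):
            x, y = batch
            # keep one scalar per step so the epoch-end reduction stays cheap
            tensors = self(x).sum()
            self.outputs.append(tensors)
            return tensors

        def on_test_epoch_end(self):
            # stacking first means all_gather returns a single tensor of shape (world_size, num_steps)
            mean = torch.mean(self.all_gather(torch.stack(self.outputs)))
            self.outputs.clear()  # free memory
            if self.trainer.is_global_zero:
                # rank_zero_only=True avoids deadlocks when only rank 0 logs
                self.log("my_reduced_metric", mean, rank_zero_only=True)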
135 changes: 43 additions & 92 deletions docs/source-pytorch/common/lightning_module.rst
@@ -182,10 +182,8 @@ Under the hood, Lightning does the following (pseudocode):
 model.train()
 torch.set_grad_enabled(True)

-outs = []
 for batch_idx, batch in enumerate(train_dataloader):
     loss = training_step(batch, batch_idx)
-    outs.append(loss.detach())

     # clear gradients
     optimizer.zero_grad()
@@ -214,7 +212,7 @@ If you want to calculate epoch-level metrics and log them, use :meth:`~pytorch_l
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
The :meth:`~pytorch_lightning.core.module.LightningModule.log` object automatically reduces the
The :meth:`~pytorch_lightning.core.module.LightningModule.log` method automatically reduces the
requested metrics across a complete epoch and devices. Here's the pseudocode of what it does under the hood:

.. code-block:: python
@@ -223,59 +221,44 @@ requested metrics across a complete epoch and devices. Here's the pseudocode of
     for batch_idx, batch in enumerate(train_dataloader):
         # forward
         loss = training_step(batch, batch_idx)
-        outs.append(loss)
+        outs.append(loss.detach())

         # clear gradients
         optimizer.zero_grad()

         # backward
         loss.backward()

         # update parameters
         optimizer.step()

-    epoch_metric = torch.mean(torch.stack([x for x in outs]))
+    # note: in reality, we do this incrementally, instead of keeping all outputs in memory
+    epoch_metric = torch.mean(torch.stack(outs))

 Train Epoch-level Operations
 ============================

 If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.module.LightningModule.training_step`,
-override the :meth:`~pytorch_lightning.core.module.LightningModule.training_epoch_end` method.
-
-.. code-block:: python
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        preds = ...
-        return {"loss": loss, "other_stuff": preds}
-
-    def training_epoch_end(self, training_step_outputs):
-        all_preds = torch.stack(training_step_outputs)
-        ...
-
-The matching pseudocode is:
+override the :meth:`~pytorch_lightning.core.module.LightningModule.on_training_epoch_end` method.

 .. code-block:: python

-    outs = []
-    for batch_idx, batch in enumerate(train_dataloader):
-        # forward
-        loss = training_step(batch, batch_idx)
-        outs.append(loss)
+    def __init__(self):
+        super().__init__()
+        self.training_step_outputs = []

-        # clear gradients
-        optimizer.zero_grad()
-
-        # backward
-        loss.backward()
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        preds = ...
+        self.training_step_outputs.append(preds)
+        return loss

-        # update parameters
-        optimizer.step()
-
-    training_epoch_end(outs)
+    def on_train_epoch_end(self):
+        all_preds = torch.stack(self.training_step_outputs)
+        self.training_step_outputs.clear()  # free memory
+        ...
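The accumulate-in-an-attribute idiom shown above generalizes to any per-step value. Below is a sketch of a complete module using it; the `Linear(32, 10)` classifier, the cross-entropy loss, and the logged `mean_confidence` metric are assumptions made for illustration and do not come from the diff.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class EpochAggregatingModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 10)
            self.training_step_outputs = []  # persists across steps, cleared each epoch

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.layer(x)
            loss = F.cross_entropy(y_hat, y)
            # store only what the epoch-end hook needs; detach so the graph is not kept alive
            self.training_step_outputs.append(y_hat.softmax(dim=-1).detach())
            return loss

        def on_train_epoch_end(self):
            all_preds = torch.cat(self.training_step_outputs)
            self.log("mean_confidence", all_preds.max(dim=-1).values.mean())
            self.training_step_outputs.clear()  # free memory

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

The same structure applies to validation and test steps with their own output lists, as the later hunks in this file show.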
Training with DataParallel
==========================
@@ -309,15 +292,10 @@ method which will have outputs from all the devices and you can accumulate to ge
         return (losses[0] + losses[1]) / 2

-    def training_epoch_end(self, training_step_outputs):
-        for out in training_step_outputs:
-            ...
-
 Here is the Lightning training pseudo-code for DP:

 .. code-block:: python

-    outs = []
     for batch_idx, train_batch in enumerate(train_dataloader):
         batches = split_batch(train_batch)
         dp_outs = []
@@ -327,12 +305,7 @@ Here is the Lightning training pseudo-code for DP:
             dp_outs.append(dp_out)

         # 2
-        out = training_step_end(dp_outs)
-        outs.append(out)
-
-    # do something with the outputs for all batches
-    # 3
-    training_epoch_end(outs)
+        training_step_end(dp_outs)
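In the DP flow above, ``training_step_end`` runs on the main process with the per-GPU results gathered into one structure. A minimal sketch, assuming ``training_step`` returns a dict with a ``"loss"`` entry as in the surrounding example:

.. code-block:: python

    def training_step_end(self, batch_parts):
        # with DP, each value in batch_parts is stacked across the GPUs that
        # processed a slice of the batch, e.g. shape (num_gpus,) for a scalar loss
        losses = batch_parts["loss"]
        return losses.mean()  # a single scalar is used for the backward pass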
------------------

@@ -400,21 +373,30 @@ Validation Epoch-level Metrics
 ==============================

 If you need to do something with all the outputs of each :meth:`~pytorch_lightning.core.module.LightningModule.validation_step`,
-override the :meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end` method. Note that this method is called before :meth:`~pytorch_lightning.core.module.LightningModule.training_epoch_end`.
+override the :meth:`~pytorch_lightning.core.module.LightningModule.on_validation_epoch_end` method.
+Note that this method is called before :meth:`~pytorch_lightning.core.module.LightningModule.on_train_epoch_end`.

 .. code-block:: python

-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self.model(x)
-        loss = F.cross_entropy(y_hat, y)
-        pred = ...
-        return pred
+    def __init__(self):
+        super().__init__()
+        self.validation_step_outputs = []

-    def validation_epoch_end(self, validation_step_outputs):
-        all_preds = torch.stack(validation_step_outputs)
-        ...
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = F.cross_entropy(y_hat, y)
+        pred = ...
+        self.validation_step_outputs.append(pred)
+        return pred
+
+    def on_validation_epoch_end(self):
+        all_preds = torch.stack(self.validation_step_outputs)
+        self.validation_step_outputs.clear()  # free memory
+        ...
Validating with DataParallel
============================
@@ -448,15 +430,10 @@ method which will have outputs from all the devices and you can accumulate to ge
         return (losses[0] + losses[1]) / 2

-    def validation_epoch_end(self, validation_step_outputs):
-        for out in validation_step_outputs:
-            ...
-
 Here is the Lightning validation pseudo-code for DP:

 .. code-block:: python

-    outs = []
     for batch in dataloader:
         batches = split_batch(batch)
         dp_outs = []
@@ -466,12 +443,7 @@ Here is the Lightning validation pseudo-code for DP:
             dp_outs.append(dp_out)

         # 2
-        out = validation_step_end(dp_outs)
-        outs.append(out)
-
-    # do something with the outputs for all batches
-    # 3
-    validation_epoch_end(outs)
+        validation_step_end(dp_outs)
----------------

@@ -924,12 +896,6 @@ test_step_end
 .. automethod:: pytorch_lightning.core.module.LightningModule.test_step_end
    :noindex:

-test_epoch_end
-~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.module.LightningModule.test_epoch_end
-   :noindex:
-
 to_onnx
 ~~~~~~~

@@ -954,11 +920,6 @@ training_step_end
 .. automethod:: pytorch_lightning.core.module.LightningModule.training_step_end
    :noindex:

-training_epoch_end
-~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.module.LightningModule.training_epoch_end
-   :noindex:
-
 unfreeze
 ~~~~~~~~

@@ -983,12 +944,6 @@ validation_step_end
 .. automethod:: pytorch_lightning.core.module.LightningModule.validation_step_end
    :noindex:

-validation_epoch_end
-~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.module.LightningModule.validation_epoch_end
-   :noindex:
-

-----------

Properties
@@ -1247,7 +1202,8 @@ for more information.
         transfer_batch_to_device()
         on_after_batch_transfer()
-        training_step()
+        out = training_step()
+        training_step_end(out)
         on_before_zero_grad()
         optimizer_zero_grad()
@@ -1264,8 +1220,6 @@ for more information.
         if should_check_val:
             val_loop()

-    # end training epoch
-    training_epoch_end()
     on_train_epoch_end()
@@ -1277,7 +1231,6 @@ for more information.
     on_validation_start()
     on_validation_epoch_start()

-    val_outs = []
     for batch_idx, batch in enumerate(val_dataloader()):
         on_validation_batch_start(batch, batch_idx)
@@ -1286,11 +1239,9 @@ for more information.
         batch = on_after_batch_transfer(batch)

         out = validation_step(batch, batch_idx)
         out = validation_step_end(out)

         on_validation_batch_end(batch, batch_idx)
-        val_outs.append(out)

-    validation_epoch_end(val_outs)
     on_validation_epoch_end()
     on_validation_end()
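The hook order laid out in the pseudocode above (in particular, the validation loop finishing before ``on_train_epoch_end``) can be observed with a toy run. The sketch below is an assumption-heavy illustration, not part of the commit: random data, a one-layer model, and trainer limits chosen only to keep the run short.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class HookOrderModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).pow(2).mean()

        def validation_step(self, batch, batch_idx):
            (x,) = batch
            self.log("val_loss", self.layer(x).pow(2).mean())

        def on_validation_epoch_end(self):
            print("on_validation_epoch_end: runs inside the fit epoch")

        def on_train_epoch_end(self):
            print("on_train_epoch_end: runs after the validation loop")

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    data = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=16)
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(HookOrderModel(), train_dataloaders=data, val_dataloaders=data)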
4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/logging.rst
@@ -152,7 +152,7 @@ The :meth:`~pytorch_lightning.core.module.LightningModule.log` method has a few
    * - Hook
      - on_step
      - on_epoch
-   * - on_train_start, on_train_epoch_start, on_train_epoch_end, training_epoch_end
+   * - on_train_start, on_train_epoch_start, on_train_epoch_end
      - False
      - True
    * - on_before_backward, on_after_backward, on_before_optimizer_step, on_before_zero_grad
@@ -161,7 +161,7 @@ The :meth:`~pytorch_lightning.core.module.LightningModule.log` method has a few
    * - on_train_batch_start, on_train_batch_end, training_step, training_step_end
      - True
      - False
-   * - on_validation_start, on_validation_epoch_start, on_validation_epoch_end, validation_epoch_end
+   * - on_validation_start, on_validation_epoch_start, on_validation_epoch_end
      - False
      - True
    * - on_validation_batch_start, on_validation_batch_end, validation_step, validation_step_end
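To make the defaults in this table concrete: any of them can be overridden per call to ``self.log``. A sketch, with hypothetical metric names and a hypothetical ``compute_loss`` helper that are not part of the diff:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        # training_step defaults to on_step=True, on_epoch=False;
        # requesting both also produces an epoch-level aggregate
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        # epoch-level hooks default to on_step=False, on_epoch=True,
        # so this logs once per epoch with no extra arguments
        self.log("val_metric", self.some_epoch_metric)  # hypothetical attribute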
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/manual_optimization.rst
@@ -278,7 +278,7 @@ If you want to call schedulers that require a metric value after each epoch, con
         self.automatic_optimization = False


-    def training_epoch_end(self, outputs):
+    def on_train_epoch_end(self):
         sch = self.lr_schedulers()

         # If the selected scheduler is a ReduceLROnPlateau scheduler.
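For completeness, the manual-optimization pattern this hunk touches looks roughly as follows when written out in full. The model, the hand-averaged epoch loss passed to the scheduler, and the optimizer settings are assumptions for illustration only.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class ManualPlateauModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False
            self.layer = torch.nn.Linear(16, 1)
            self.epoch_losses = []

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            (x,) = batch
            loss = self.layer(x).pow(2).mean()

            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()

            self.epoch_losses.append(loss.detach())
            return loss

        def on_train_epoch_end(self):
            sch = self.lr_schedulers()
            # ReduceLROnPlateau needs the monitored value passed explicitly
            sch.step(torch.stack(self.epoch_losses).mean())
            self.epoch_losses.clear()

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
            return [optimizer], [scheduler]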
6 changes: 3 additions & 3 deletions docs/source-pytorch/starter/style_guide.rst
@@ -149,19 +149,19 @@ In practice, the code looks like this:
     def training_step_end(...):

-    def training_epoch_end(...):
+    def on_train_epoch_end(...):

     def validation_step(...):

     def validation_step_end(...):

-    def validation_epoch_end(...):
+    def on_validation_epoch_end(...):

     def test_step(...):

     def test_step_end(...):

-    def test_epoch_end(...):
+    def on_test_epoch_end(...):

     def configure_optimizers(...):
2 changes: 1 addition & 1 deletion docs/source-pytorch/visualize/logging_advanced.rst
@@ -358,7 +358,7 @@ In LightningModule
    * - on_after_backward, on_before_backward, on_before_optimizer_step, optimizer_step, configure_gradient_clipping, on_before_zero_grad, training_step, training_step_end
      - True
      - False
-   * - training_epoch_end, test_epoch_end, test_step, test_step_end, validation_epoch_end, validation_step, validation_step_end
+   * - test_step, test_step_end, validation_step, validation_step_end
      - False
      - True
