Commit

Update src

carmocca committed Jan 30, 2023
1 parent 70b8cb6 commit dba6fa6
Showing 15 changed files with 93 additions and 338 deletions.
3 changes: 0 additions & 3 deletions src/lightning_app/utilities/introspection.py
@@ -79,16 +79,13 @@ class LightningModuleVisitor(LightningVisitor):
"save_hyperparameters",
"test_step",
"test_step_end",
"test_epoch_end",
"to_onnx",
"to_torchscript",
"training_step",
"training_step_end",
"training_epoch_end",
"unfreeze",
"validation_step",
"validation_step_end",
"validation_epoch_end",
}

hooks: Set[str] = {
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed deadlock detection / process reconciliation (`PL_RECONCILE_PROCESS=1`) ([#16204](https://github.com/Lightning-AI/lightning/pull/16204))

- Removed the `{training,validation,test}_epoch_end` hooks, which retained step outputs in memory. Implement the corresponding `on_*_epoch_end` hooks instead ([#16520](https://github.com/Lightning-AI/lightning/pull/16520))
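
  A minimal migration sketch for this entry (an editorial illustration; the module, metric, and attribute names are made up, not taken from the commit): cache step outputs on the module in `validation_step` and aggregate them in `on_validation_epoch_end`, mirroring the docstring examples added in this commit. The same pattern applies to `training_step`/`on_train_epoch_end` and `test_step`/`on_test_epoch_end`.

  ```python
  import torch
  import pytorch_lightning as pl


  class LitModel(pl.LightningModule):  # hypothetical example module
      def __init__(self):
          super().__init__()
          self.layer = torch.nn.Linear(32, 2)
          # replaces the `outputs` list previously passed to `validation_epoch_end`
          self.validation_step_outputs = []

      def validation_step(self, batch, batch_idx):
          loss = self.layer(batch).sum()
          self.validation_step_outputs.append(loss)
          return loss

      def on_validation_epoch_end(self):
          # aggregate the cached step outputs, then free the memory
          epoch_mean = torch.stack(self.validation_step_outputs).mean()
          self.log("val_epoch_mean", epoch_mean)
          self.validation_step_outputs.clear()
  ```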

- Removed support for the experimental `PL_FAULT_TOLERANT_TRAINING` environment flag ([#16516](https://github.com/Lightning-AI/lightning/pull/16516), [#16533](https://github.com/Lightning-AI/lightning/pull/16533))

- Removed the deprecated `LightningCLI` arguments ([#16380](https://github.com/Lightning-AI/lightning/pull/16380))
25 changes: 22 additions & 3 deletions src/pytorch_lightning/callbacks/callback.py
@@ -95,10 +95,29 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the train epoch ends.
- To access all batch outputs at the end of the epoch, either:
- 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
- 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
+ To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
+ :class:`pytorch_lightning.LightningModule` and access them in this hook:
.. code-block:: python

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
"""

def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
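
A self-contained usage sketch (an editorial addition, not part of this diff; all names are illustrative) showing how the callback pattern above can be wired into a `Trainer`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class DemoModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.training_step_outputs = []  # cache step outputs for the callback to read

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        self.training_step_outputs.append(loss.detach())
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class EpochMeanCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # aggregate the cached step outputs, then free the memory
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        pl_module.training_step_outputs.clear()


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(max_epochs=1, callbacks=[EpochMeanCallback()], logger=False, enable_checkpointing=False)
    trainer.fit(DemoModule(), data)
```

Either the LightningModule (as in the hook docstrings) or a Callback (as here) can own the aggregation; the key point is clearing the cached list so step outputs are not retained for the whole run.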
23 changes: 20 additions & 3 deletions src/pytorch_lightning/core/hooks.py
@@ -169,10 +169,27 @@ def on_train_epoch_start(self) -> None:
def on_train_epoch_end(self) -> None:
"""Called in the training loop at the very end of the epoch.
- To access all batch outputs at the end of the epoch, either:
- 1. Implement `training_epoch_end` in the LightningModule OR
- 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook
+ To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
+ :class:`pytorch_lightning.LightningModule` and access them in this hook:
.. code-block:: python

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss

        def on_train_epoch_end(self):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(self.training_step_outputs).mean()
            self.log("training_epoch_mean", epoch_mean)
            # free up the memory
            self.training_step_outputs.clear()
"""

def on_validation_epoch_start(self) -> None:
154 changes: 1 addition & 153 deletions src/pytorch_lightning/core/module.py
@@ -50,13 +50,7 @@
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import (
_METRIC,
EPOCH_OUTPUT,
LRSchedulerPLType,
LRSchedulerTypeUnion,
STEP_OUTPUT,
)
from pytorch_lightning.utilities.types import _METRIC, LRSchedulerPLType, LRSchedulerTypeUnion, STEP_OUTPUT

warning_cache = WarningCache()
log = logging.getLogger(__name__)
@@ -767,51 +761,11 @@ def training_step_end(self, training_step_outputs):
See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
"""

def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
"""Called at the end of the training epoch with the outputs of all training steps. Use this in case you
need to do something with all the outputs returned by :meth:`training_step`.
.. code-block:: python
# the pseudocode for these calls
train_outs = []
for train_batch in train_data:
out = training_step(train_batch)
train_outs.append(out)
training_epoch_end(train_outs)
Args:
outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists
have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed.
Return:
None
Note:
If this method is not overridden, this won't be called.
.. code-block:: python
def training_epoch_end(self, training_step_outputs):
# do something with all training_step outputs
for out in training_step_outputs:
...
"""

def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
r"""
Operates on a single batch of data from the validation set.
In this step you might generate examples or calculate anything of interest like accuracy.
.. code-block:: python
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
out = validation_step(val_batch)
val_outs.append(out)
validation_epoch_end(val_outs)
Args:
batch: The output of your :class:`~torch.utils.data.DataLoader`.
batch_idx: The index of this batch.
@@ -825,13 +779,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
.. code-block:: python
# pseudocode of order
val_outs = []
for val_batch in val_data:
out = validation_step(val_batch)
if defined("validation_step_end"):
out = validation_step_end(out)
val_outs.append(out)
val_outs = validation_epoch_end(val_outs)
.. code-block:: python
@@ -940,65 +891,12 @@ def validation_step_end(self, val_step_outputs):
See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
"""

def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
"""Called at the end of the validation epoch with the outputs of all validation steps.
.. code-block:: python
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
out = validation_step(val_batch)
val_outs.append(out)
validation_epoch_end(val_outs)
Args:
outputs: List of outputs you defined in :meth:`validation_step`, or if there
are multiple dataloaders, a list containing a list of outputs for each dataloader.
Return:
None
Note:
If you didn't define a :meth:`validation_step`, this won't be called.
Examples:
With a single dataloader:
.. code-block:: python
def validation_epoch_end(self, val_step_outputs):
for out in val_step_outputs:
...
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
.. code-block:: python
def validation_epoch_end(self, outputs):
for dataloader_output_result in outputs:
dataloader_outs = dataloader_output_result.dataloader_i_outputs
self.log("final_metric", final_value)
"""

def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
r"""
Operates on a single batch of data from the test set.
In this step you'd normally generate examples or calculate anything of interest
such as accuracy.
.. code-block:: python
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
out = test_step(test_batch)
test_outs.append(out)
test_epoch_end(test_outs)
Args:
batch: The output of your :class:`~torch.utils.data.DataLoader`.
batch_idx: The index of this batch.
@@ -1118,56 +1016,6 @@ def test_step_end(self, output_results):
See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
"""

def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
"""Called at the end of a test epoch with the output of all test steps.
.. code-block:: python
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
out = test_step(test_batch)
test_outs.append(out)
test_epoch_end(test_outs)
Args:
outputs: List of outputs you defined in :meth:`test_step_end`, or if there
are multiple dataloaders, a list containing a list of outputs for each dataloader
Return:
None
Note:
If you didn't define a :meth:`test_step`, this won't be called.
Examples:
With a single dataloader:
.. code-block:: python
def test_epoch_end(self, outputs):
# do something with the outputs of all test batches
all_test_preds = test_step_outputs.predictions
some_result = calc_all_results(all_test_preds)
self.log(some_result)
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each test step for that dataloader.
.. code-block:: python
def test_epoch_end(self, outputs):
final_value = 0
for dataloader_outputs in outputs:
for test_step_out in dataloader_outputs:
# do something
final_value += test_step_out
self.log("final_metric", final_value)
"""

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
"""Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it
calls :meth:`~pytorch_lightning.core.module.LightningModule.forward`. Override to add any processing logic.
20 changes: 4 additions & 16 deletions src/pytorch_lightning/demos/boring_classes.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Dict, Iterator, List, Optional, Tuple, Union
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -23,7 +23,7 @@
from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT


class RandomDictDataset(Dataset):
@@ -89,14 +89,14 @@ class TestModel(BoringModel):
def training_step(self, ...):
... # do your own thing
training_epoch_end = None # disable hook
training_step_end = None # disable hook
or
Example::
model = BoringModel()
model.training_epoch_end = None # disable hook
model.training_step_end = None # disable hook
"""
super().__init__()
self.layer = torch.nn.Linear(32, 2)
@@ -120,24 +120,12 @@ def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:
def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
return training_step_outputs

def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["loss"] for x in outputs]).mean()

def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
return {"x": self.step(batch)}

def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["x"] for x in outputs]).mean()

def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
return {"y": self.step(batch)}

def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["y"] for x in outputs]).mean()

def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)