diff --git a/src/lightning_app/utilities/introspection.py b/src/lightning_app/utilities/introspection.py index b1fb5da948a95..556457e86b9e0 100644 --- a/src/lightning_app/utilities/introspection.py +++ b/src/lightning_app/utilities/introspection.py @@ -79,16 +79,13 @@ class LightningModuleVisitor(LightningVisitor): "save_hyperparameters", "test_step", "test_step_end", - "test_epoch_end", "to_onnx", "to_torchscript", "training_step", "training_step_end", - "training_epoch_end", "unfreeze", "validation_step", "validation_step_end", - "validation_epoch_end", } hooks: Set[str] = { diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 0c3ca10abe952..dfd714c5170c7 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deadlock detection / process reconciliation (`PL_RECONCILE_PROCESS=1`) ([#16204](https://github.com/Lightning-AI/lightning/pull/16204)) +- Removed the `{training,validation,test}_epoch_end` hooks which would retain step outputs in memory. 
As an alternative, implement the corresponding `on_*_epoch_end` hooks and cache the step outputs you need yourself ([#16520](https://github.com/Lightning-AI/lightning/pull/16520)) + - Removed support for the experimental `PL_FAULT_TOLERANT_TRAINING` environment flag ([#16516](https://github.com/Lightning-AI/lightning/pull/16516), [#16533](https://github.com/Lightning-AI/lightning/pull/16533)) - Removed the deprecated `LightningCLI` arguments ([#16380](https://github.com/Lightning-AI/lightning/pull/16380)) diff --git a/src/pytorch_lightning/callbacks/callback.py b/src/pytorch_lightning/callbacks/callback.py index d8cfdb5399ca6..1716c5f3becc2 100644 --- a/src/pytorch_lightning/callbacks/callback.py +++ b/src/pytorch_lightning/callbacks/callback.py @@ -95,10 +95,29 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when the train epoch ends. - To access all batch outputs at the end of the epoch, either: + To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the + :class:`pytorch_lightning.LightningModule` and access them in this hook: - 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR - 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook. + .. code-block:: python + + class MyLightningModule(L.LightningModule): + def __init__(self): + super().__init__() + self.training_step_outputs = [] + + def training_step(self): + loss = ... 
+ self.training_step_outputs.append(loss) + return loss + + + class MyCallback(L.Callback): + def on_train_epoch_end(self, trainer, pl_module): + # do something with all training_step outputs, for example: + epoch_mean = torch.stack(pl_module.training_step_outputs).mean() + pl_module.log("training_epoch_mean", epoch_mean) + # free up the memory + pl_module.training_step_outputs.clear() """ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/src/pytorch_lightning/core/hooks.py b/src/pytorch_lightning/core/hooks.py index 025e9bb74c5ca..8d417ec0beb33 100644 --- a/src/pytorch_lightning/core/hooks.py +++ b/src/pytorch_lightning/core/hooks.py @@ -169,10 +169,27 @@ def on_train_epoch_start(self) -> None: def on_train_epoch_end(self) -> None: """Called in the training loop at the very end of the epoch. - To access all batch outputs at the end of the epoch, either: + To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the + :class:`pytorch_lightning.LightningModule` and access them in this hook: - 1. Implement `training_epoch_end` in the LightningModule OR - 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook + .. code-block:: python + + class MyLightningModule(L.LightningModule): + def __init__(self): + super().__init__() + self.training_step_outputs = [] + + def training_step(self): + loss = ... 
+ self.training_step_outputs.append(loss) + return loss + + def on_train_epoch_end(self): + # do something with all training_step outputs, for example: + epoch_mean = torch.stack(self.training_step_outputs).mean() + self.log("training_epoch_mean", epoch_mean) + # free up the memory + self.training_step_outputs.clear() """ def on_validation_epoch_start(self) -> None: diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index eac840613a7fd..82676d7ca2e54 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -50,13 +50,7 @@ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import ( - _METRIC, - EPOCH_OUTPUT, - LRSchedulerPLType, - LRSchedulerTypeUnion, - STEP_OUTPUT, -) +from pytorch_lightning.utilities.types import _METRIC, LRSchedulerPLType, LRSchedulerTypeUnion, STEP_OUTPUT warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -767,51 +761,11 @@ def training_step_end(self, training_step_outputs): See the :ref:`Multi GPU Training ` guide for more details. """ - def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - """Called at the end of the training epoch with the outputs of all training steps. Use this in case you - need to do something with all the outputs returned by :meth:`training_step`. - - .. code-block:: python - - # the pseudocode for these calls - train_outs = [] - for train_batch in train_data: - out = training_step(train_batch) - train_outs.append(out) - training_epoch_end(train_outs) - - Args: - outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists - have the dimensions (n_batches, n_optimizers). 
Dimensions of length 1 are squeezed. - - Return: - None - - Note: - If this method is not overridden, this won't be called. - - .. code-block:: python - - def training_epoch_end(self, training_step_outputs): - # do something with all training_step outputs - for out in training_step_outputs: - ... - """ - def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the validation set. In this step you'd might generate examples or calculate anything of interest like accuracy. - .. code-block:: python - - # the pseudocode for these calls - val_outs = [] - for val_batch in val_data: - out = validation_step(val_batch) - val_outs.append(out) - validation_epoch_end(val_outs) - Args: batch: The output of your :class:`~torch.utils.data.DataLoader`. batch_idx: The index of this batch. @@ -825,13 +779,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: .. code-block:: python # pseudocode of order - val_outs = [] for val_batch in val_data: out = validation_step(val_batch) if defined("validation_step_end"): out = validation_step_end(out) - val_outs.append(out) - val_outs = validation_epoch_end(val_outs) .. code-block:: python @@ -940,65 +891,12 @@ def validation_step_end(self, val_step_outputs): See the :ref:`Multi GPU Training ` guide for more details. """ - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - """Called at the end of the validation epoch with the outputs of all validation steps. - - .. code-block:: python - - # the pseudocode for these calls - val_outs = [] - for val_batch in val_data: - out = validation_step(val_batch) - val_outs.append(out) - validation_epoch_end(val_outs) - - Args: - outputs: List of outputs you defined in :meth:`validation_step`, or if there - are multiple dataloaders, a list containing a list of outputs for each dataloader. 
- - Return: - None - - Note: - If you didn't define a :meth:`validation_step`, this won't be called. - - Examples: - With a single dataloader: - - .. code-block:: python - - def validation_epoch_end(self, val_step_outputs): - for out in val_step_outputs: - ... - - With multiple dataloaders, `outputs` will be a list of lists. The outer list contains - one entry per dataloader, while the inner list contains the individual outputs of - each validation step for that dataloader. - - .. code-block:: python - - def validation_epoch_end(self, outputs): - for dataloader_output_result in outputs: - dataloader_outs = dataloader_output_result.dataloader_i_outputs - - self.log("final_metric", final_value) - """ - def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy. - .. code-block:: python - - # the pseudocode for these calls - test_outs = [] - for test_batch in test_data: - out = test_step(test_batch) - test_outs.append(out) - test_epoch_end(test_outs) - Args: batch: The output of your :class:`~torch.utils.data.DataLoader`. batch_idx: The index of this batch. @@ -1118,56 +1016,6 @@ def test_step_end(self, output_results): See the :ref:`Multi GPU Training ` guide for more details. """ - def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - """Called at the end of a test epoch with the output of all test steps. - - .. code-block:: python - - # the pseudocode for these calls - test_outs = [] - for test_batch in test_data: - out = test_step(test_batch) - test_outs.append(out) - test_epoch_end(test_outs) - - Args: - outputs: List of outputs you defined in :meth:`test_step_end`, or if there - are multiple dataloaders, a list containing a list of outputs for each dataloader - - Return: - None - - Note: - If you didn't define a :meth:`test_step`, this won't be called. 
- - Examples: - With a single dataloader: - - .. code-block:: python - - def test_epoch_end(self, outputs): - # do something with the outputs of all test batches - all_test_preds = test_step_outputs.predictions - - some_result = calc_all_results(all_test_preds) - self.log(some_result) - - With multiple dataloaders, `outputs` will be a list of lists. The outer list contains - one entry per dataloader, while the inner list contains the individual outputs of - each test step for that dataloader. - - .. code-block:: python - - def test_epoch_end(self, outputs): - final_value = 0 - for dataloader_outputs in outputs: - for test_step_out in dataloader_outputs: - # do something - final_value += test_step_out - - self.log("final_metric", final_value) - """ - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it calls :meth:`~pytorch_lightning.core.module.LightningModule.forward`. Override to add any processing logic. diff --git a/src/pytorch_lightning/demos/boring_classes.py b/src/pytorch_lightning/demos/boring_classes.py index d125d494ddedb..0a9aff740d81d 100644 --- a/src/pytorch_lightning/demos/boring_classes.py +++ b/src/pytorch_lightning/demos/boring_classes.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import cast, Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -23,7 +23,7 @@ from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER from pytorch_lightning import LightningDataModule, LightningModule from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT class RandomDictDataset(Dataset): @@ -89,14 +89,14 @@ class TestModel(BoringModel): def training_step(self, ...): ... # do your own thing - training_epoch_end = None # disable hook + training_step_end = None # disable hook or Example:: model = BoringModel() - model.training_epoch_end = None # disable hook + model.training_step_end = None # disable hook """ super().__init__() self.layer = torch.nn.Linear(32, 2) @@ -120,24 +120,12 @@ def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT: return training_step_outputs - def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - outputs = cast(List[Dict[str, Tensor]], outputs) - torch.stack([x["loss"] for x in outputs]).mean() - def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: return {"x": self.step(batch)} - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - outputs = cast(List[Dict[str, Tensor]], outputs) - torch.stack([x["x"] for x in outputs]).mean() - def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: return {"y": self.step(batch)} - def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - outputs = cast(List[Dict[str, Tensor]], outputs) - torch.stack([x["y"] for x in outputs]).mean() - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]: optimizer 
= torch.optim.SGD(self.layer.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py index cac778fdfcbdd..66fd79368a78a 100644 --- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -31,7 +31,6 @@ from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import EPOCH_OUTPUT if _RICH_AVAILABLE: from rich import get_console @@ -51,7 +50,6 @@ def __init__(self, verbose: bool = True) -> None: self.verbose = verbose self._results = _ResultCollection(training=False) - self._outputs: List[EPOCH_OUTPUT] = [] self._logged_outputs: List[_OUT_DICT] = [] self._max_batches: List[Union[int, float]] = [] self._has_run: bool = False @@ -113,7 +111,6 @@ def reset(self) -> None: """Resets the internal state of the loop.""" self._max_batches = self._get_max_batches() # bookkeeping - self._outputs = [] self._logged_outputs = [] if isinstance(self._max_batches, int): @@ -154,10 +151,7 @@ def batch_to_device(batch: Any) -> Any: kwargs = OrderedDict() if self.num_dataloaders > 1: kwargs["dataloader_idx"] = dataloader_idx - dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) - - # store batch level output per dataloader - self._outputs.append(dl_outputs) + self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) if not self.trainer.sanity_checking: # indicate the loop has run @@ -180,10 +174,7 @@ def on_run_end(self) -> List[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in 
`on_advance_end` self.trainer._logger_connector.epoch_end_reached() - - # hook - self._evaluation_epoch_end(self._outputs) - self._outputs = [] # free memory + self.trainer._logger_connector._evaluation_epoch_end() # hook self._on_evaluation_epoch_end() @@ -216,7 +207,6 @@ def teardown(self) -> None: self._data_fetcher.teardown() self._data_fetcher = None self._results.cpu() - self.epoch_loop.teardown() def _get_max_batches(self) -> List[Union[int, float]]: """Returns the max number of batches for each dataloader.""" @@ -279,19 +269,6 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: self.trainer._call_callback_hooks(hook_name, *args, **kwargs) self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs) - def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None: - """Runs ``{validation/test}_epoch_end``""" - self.trainer._logger_connector._evaluation_epoch_end() - - # with a single dataloader don't pass a 2D list - output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = ( - outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs - ) - - # call the model epoch end - hook_name = "test_epoch_end" if self.trainer.testing else "validation_epoch_end" - self.trainer._call_lightning_module_hook(hook_name, output_or_outputs) - def _on_evaluation_epoch_end(self) -> None: """Runs ``on_{validation/test}_epoch_end`` hook.""" hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" diff --git a/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 58547d6a44fa3..f31c94ed697ca 100644 --- a/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -13,7 +13,6 @@ # limitations under the License. 
from collections import OrderedDict -from functools import lru_cache from typing import Any, Optional, Union from pytorch_lightning.loops.loop import _Loop @@ -21,8 +20,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import SIGTERMException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher -from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT class _EvaluationEpochLoop(_Loop): @@ -36,7 +34,6 @@ def __init__(self) -> None: super().__init__() self.batch_progress = BatchProgress() - self._outputs: EPOCH_OUTPUT = [] self._dl_max_batches: Union[int, float] = 0 self._data_fetcher: Optional[AbstractDataFetcher] = None self._dl_batch_idx = [0] @@ -46,9 +43,7 @@ def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.batch_progress.current.completed >= self._dl_max_batches - def run( - self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict - ) -> EPOCH_OUTPUT: + def run(self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: self.reset() self.on_run_start(data_fetcher, dl_max_batches, kwargs) while not self.done: @@ -58,13 +53,12 @@ def run( except StopIteration: break self._restarting = False - return self.on_run_end() + self.on_run_end() def reset(self) -> None: """Resets the loop's internal state.""" self._dl_max_batches = 0 self._data_fetcher = None - self._outputs = [] if not self.restarting: self.batch_progress.reset_on_run() @@ -154,22 +148,11 @@ def advance( self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx]) self._dl_batch_idx[dataloader_idx] += 1 - # track epoch level outputs - if 
self._should_track_batch_outputs_for_epoch_end() and output is not None: - self._outputs.append(output) - if not self.batch_progress.is_last_batch and self.trainer.received_sigterm: raise SIGTERMException - def on_run_end(self) -> EPOCH_OUTPUT: - """Returns the outputs of the whole run.""" - outputs, self._outputs = self._outputs, [] # free memory + def on_run_end(self) -> None: self._data_fetcher = None - return outputs - - def teardown(self) -> None: - # in case the model changes - self._should_track_batch_outputs_for_epoch_end.cache_clear() def _num_completed_batches_reached(self) -> bool: epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches @@ -253,13 +236,5 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde kwargs.move_to_end("batch", last=False) return kwargs - @lru_cache(1) - def _should_track_batch_outputs_for_epoch_end(self) -> bool: - """Whether the batch outputs should be stored for later usage.""" - model = self.trainer.lightning_module - if self.trainer.testing: - return is_overridden("test_epoch_end", model) - return is_overridden("validation_epoch_end", model) - def _reset_dl_batch_idx(self, num_dataloaders: int) -> None: self._dl_batch_idx = [0] * num_dataloaders diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index ede1fff787ee7..6f1db6b9c0dfc 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -13,11 +13,12 @@ # limitations under the License. 
import math from collections import defaultdict, OrderedDict -from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload, Tuple, Union +from typing import Any, DefaultDict, Dict, Generator, List, Optional, Tuple, Union import numpy as np import torch from lightning_utilities.core.apply_func import apply_to_collection +from typing_extensions import overload import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports @@ -29,12 +30,10 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] -_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] class _TrainingEpochLoop(loops._Loop): @@ -77,7 +76,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self.val_loop = loops._EvaluationLoop(verbose=False) self._results = _ResultCollection(training=True) - self._outputs: _OUTPUTS_TYPE = [] self._warning_cache = WarningCache() self._batches_that_stepped: int = 0 @@ -131,7 +129,7 @@ def done(self) -> bool: return False - def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE: + def run(self, data_fetcher: AbstractDataFetcher) -> None: self.reset() self.on_run_start(data_fetcher) while not self.done: @@ -142,7 +140,6 @@ def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE: except StopIteration: break self._restarting = False - return self.on_run_end() def reset(self) -> None: """Resets the internal state of the loop for 
a new run.""" @@ -167,8 +164,6 @@ def reset(self) -> None: # seen per epoch, this is useful for tracking when validation is run multiple times per epoch self.val_loop.epoch_loop.batch_progress.total.reset() - self._outputs = [] - def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: _ = iter(data_fetcher) # creates the iterator inside the fetcher # add the previous `fetched` value to properly track `is_last_batch` with no prefetching @@ -252,13 +247,6 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: self.batch_progress.increment_completed() - if batch_output and is_overridden("training_epoch_end", self.trainer.lightning_module): - # batch_output may be empty - # automatic: can be empty if all optimizers skip their batches - # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens, - # then `advance` doesn't finish and an empty dict is returned - self._outputs.append(batch_output) - # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- @@ -288,10 +276,6 @@ def on_advance_end(self) -> None: if not self._is_training_done and self.trainer.received_sigterm: raise SIGTERMException - def on_run_end(self) -> _OUTPUTS_TYPE: - outputs, self._outputs = self._outputs, [] - return outputs - def teardown(self) -> None: self._results.cpu() self.val_loop.teardown() @@ -361,34 +345,6 @@ def _prepare_outputs_training_batch_end( array = _recursive_unpad(array) return array - @staticmethod - def _prepare_outputs_training_epoch_end( - batch_outputs: _OUTPUTS_TYPE, - lightning_module: "pl.LightningModule", - num_optimizers: int, - ) -> Union[List[List[List[Dict[str, Any]]]], List[List[Dict[str, Any]]], List[Dict[str, Any]]]: - """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook.""" - # `batch_outputs` (plural) is the same as `epoch_end_output` (singular) - if not batch_outputs: - return [] # 
type: ignore[return-value] - - # convert optimizer dicts to list - if lightning_module.automatic_optimization: - batch_outputs = apply_to_collection( - batch_outputs, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers - ) - - array = _recursive_pad(batch_outputs) - # squeeze all single-element dimensions - array = array.squeeze() - array = array.tolist() - array = _recursive_unpad(array) - # in case we squeezed from 1-element array to a 0-dim array - array = array if isinstance(array, list) else [array] - # remove residual empty lists - array = [item for item in array if not isinstance(item, list) or len(item)] - return array - def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None: """updates the lr schedulers based on the given interval.""" if interval == "step" and self._should_accumulate(): diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 695871e0c9f77..2a637c3f643b2 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -17,14 +17,12 @@ import pytorch_lightning as pl from pytorch_lightning.loops import _Loop from pytorch_lightning.loops.epoch import _TrainingEpochLoop -from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE from pytorch_lightning.loops.progress import Progress from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException, SIGTERMException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher -from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, 
rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -77,7 +75,6 @@ def __init__( self.epoch_progress = Progress() self._is_fresh_start_epoch: bool = True - self._outputs: _EPOCH_OUTPUTS_TYPE = [] self._data_fetcher: Optional[AbstractDataFetcher] = None @property @@ -241,9 +238,6 @@ def on_advance_start(self) -> None: self.trainer.reset_train_dataloader(model) self._is_fresh_start_epoch = False - # reset outputs here instead of in `reset` as they are not accumulated between epochs - self._outputs = [] - if self.trainer.train_dataloader is not None: assert isinstance(self.trainer.train_dataloader, CombinedLoader) _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed) @@ -274,31 +268,12 @@ def batch_to_device(batch: Any) -> Any: assert self._data_fetcher is not None self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device) with self.trainer.profiler.profile("run_training_epoch"): - self._outputs = self.epoch_loop.run(self._data_fetcher) + self.epoch_loop.run(self._data_fetcher) def on_advance_end(self) -> None: # inform logger the batch loop has finished self.trainer._logger_connector.epoch_end_reached() - # get the model and call model.training_epoch_end - model = self.trainer.lightning_module - if is_overridden("training_epoch_end", model) and self._outputs: - epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end( - self._outputs, - lightning_module=model, - num_optimizers=len(self.trainer.optimizers), - ) - # run lightning module hook training_epoch_end - # refresh the result for custom logging at the epoch level - epoch_end_outputs = self.trainer._call_lightning_module_hook("training_epoch_end", epoch_end_outputs) - if epoch_end_outputs is not None: - raise MisconfigurationException( - "`training_epoch_end` expects a return of None. " - "HINT: remove the return statement in `training_epoch_end`." 
- ) - # free memory - self._outputs = [] - self.epoch_progress.increment_processed() # call train epoch end hooks diff --git a/src/pytorch_lightning/loops/optimization/manual_loop.py b/src/pytorch_lightning/loops/optimization/manual_loop.py index e5c7ec4da9364..0a547e0f12eae 100644 --- a/src/pytorch_lightning/loops/optimization/manual_loop.py +++ b/src/pytorch_lightning/loops/optimization/manual_loop.py @@ -49,7 +49,7 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) elif training_step_output is not None: raise MisconfigurationException( "In manual optimization, `training_step` must either return a Tensor, " - "a dict with extras to pass to `training_epoch_end` or have no return." + "a dict with extras to pass to `training_step_end` or have no return." ) if "loss" in extra: diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index ab27c49e01991..ff3a1c8d54d17 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -50,19 +50,13 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - # ----------------------------------- - # verify model has a training step - # ----------------------------------- + # verify minimum training requirements has_training_step = is_overridden("training_step", model) if not has_training_step: raise MisconfigurationException( "No `training_step()` method defined. Lightning `Trainer` expects as minimum a" " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." 
) - - # ----------------------------------- - # verify model has optimizer - # ----------------------------------- has_optimizers = is_overridden("configure_optimizers", model) if not has_optimizers: raise MisconfigurationException( @@ -70,11 +64,11 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) + # verify gradient accumulation setup overridden_optimizer_step = is_overridden("optimizer_step", model) overridden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model) automatic_optimization = model.automatic_optimization going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() - has_overridden_optimization_functions = overridden_optimizer_step or overridden_optimizer_zero_grad if has_overridden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: rank_zero_warn( @@ -83,13 +77,9 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh " (rather, they are called on every optimization step)." ) - # ----------------------------------- - # verify model for val loop - # ----------------------------------- - + # verify minimum validation requirements has_val_loader = trainer._data_connector._val_dataloader_source.is_defined() has_val_step = is_overridden("validation_step", model) - if has_val_loader and not has_val_step: rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.") if has_val_step and not has_val_loader: @@ -98,11 +88,25 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh category=PossibleUserWarning, ) + # check legacy hooks are not present + if callable(getattr(model, "training_epoch_end", None)): + raise NotImplementedError( + f"Support for `training_epoch_end` has been removed in v2.0.0. `{type(model).__name__}` implements this" + " method. 
You can use the `on_train_epoch_end` hook instead. If you were using the `outputs` input argument" + ", you can cache them in-memory by saving them as an instance attribute." + " You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520." + ) + if callable(getattr(model, "validation_epoch_end", None)): + raise NotImplementedError( + f"Support for `validation_epoch_end` has been removed in v2.0.0. `{type(model).__name__}` implements this" + " method. You can use the `on_validation_epoch_end` hook instead. If you were using the `outputs` input" + " argument, you can cache them in-memory by saving them as an instance attribute." + " You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520." + ) + def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None: step_name = "validation_step" if stage == "val" else f"{stage}_step" - trainer_method = "validate" if stage == "val" else stage - has_step = is_overridden(step_name, model) # predict_step is not required to be overridden @@ -112,12 +116,21 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> elif not has_step and not is_overridden("forward", model): raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.") else: - # ----------------------------------- - # verify model has an eval_step - # ----------------------------------- + # verify minimum evaluation requirements if not has_step: + trainer_method = "validate" if stage == "val" else stage raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.") + # check legacy hooks are not present + epoch_end_name = "validation_epoch_end" if stage == "val" else "test_epoch_end" + if callable(getattr(model, epoch_end_name, None)): + raise NotImplementedError( + f"Support for `{epoch_end_name}` has been removed in v2.0.0. `{type(model).__name__}` implements this" + f" method. 
You can use the `on_{epoch_end_name}` hook instead. If you were using the `outputs` input" + " argument, you can cache them in-memory by saving them as an instance attribute." + " You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520." + ) + def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None: """Raise Misconfiguration exception since these hooks are not supported in DP mode.""" diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 5f7b8660bb0bc..fbedfe88049c5 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -143,15 +143,6 @@ class _LogOptions(TypedDict): "test_step_end": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True ), - "training_epoch_end": _LogOptions( - allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True - ), - "validation_epoch_end": _LogOptions( - allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True - ), - "test_epoch_end": _LogOptions( - allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True - ), "configure_optimizers": None, "train_dataloader": None, "val_dataloader": None, diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index ba1720910afe3..7e3e0fabca8b8 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -561,8 +561,7 @@ def validate( Returns: List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks - like :meth:`~pytorch_lightning.core.module.LightningModule.validation_step`, - 
:meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc. + like :meth:`~pytorch_lightning.core.module.LightningModule.validation_step` etc. The length of the list corresponds to the number of validation dataloaders used. """ if model is None: @@ -654,8 +653,7 @@ def test( Returns: List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks - like :meth:`~pytorch_lightning.core.module.LightningModule.test_step`, - :meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc. + like :meth:`~pytorch_lightning.core.module.LightningModule.test_step` etc. The length of the list corresponds to the number of test dataloaders used. """ if model is None: diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index c3bb4bbcb9251..0f30c841717ba 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -32,7 +32,6 @@ _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] STEP_OUTPUT = Union[Tensor, Dict[str, Any]] -EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] TRAIN_DATALOADERS = Union[