diff --git a/docs/source-pytorch/accelerators/accelerator_prepare.rst b/docs/source-pytorch/accelerators/accelerator_prepare.rst index f736c57472d2c..f1da6867a0eee 100644 --- a/docs/source-pytorch/accelerators/accelerator_prepare.rst +++ b/docs/source-pytorch/accelerators/accelerator_prepare.rst @@ -105,19 +105,27 @@ Note if you use any built-in metrics or custom metrics that use `TorchMetrics - def validation_epoch_end(self, outputs) -> None: + def on_validation_epoch_end(self) -> None: # since the training step/validation step and test step are run on the IPU device # we must log the average loss outside the step functions. - self.log("val_acc", torch.stack(outputs).mean(), prog_bar=True) + self.log("val_acc", torch.stack(self.val_outputs).mean(), prog_bar=True) + self.val_outputs.clear() - def test_epoch_end(self, outputs) -> None: - self.log("test_acc", torch.stack(outputs).mean()) + def on_test_epoch_end(self) -> None: + self.log("test_acc", torch.stack(self.test_outputs).mean()) + self.test_outputs.clear() def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @@ -75,9 +79,7 @@ def configure_optimizers(self): if __name__ == "__main__": dm = MNISTDataModule(batch_size=32) - model = LitClassifier() - trainer = Trainer(max_epochs=2, accelerator="ipu", devices=8) trainer.fit(model, datamodule=dm) diff --git a/src/lightning/app/utilities/introspection.py b/src/lightning/app/utilities/introspection.py index 7859ff0b00886..e36aae8e5c73c 100644 --- a/src/lightning/app/utilities/introspection.py +++ b/src/lightning/app/utilities/introspection.py @@ -79,16 +79,13 @@ class LightningModuleVisitor(LightningVisitor): "save_hyperparameters", "test_step", "test_step_end", - "test_epoch_end", "to_onnx", "to_torchscript", "training_step", "training_step_end", - "training_epoch_end", "unfreeze", "validation_step", "validation_step_end", - "validation_epoch_end", } hooks: Set[str] = { diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3013c35a7904f..442e17ed1e555 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -95,6 +95,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deadlock detection / process reconciliation (`PL_RECONCILE_PROCESS=1`) ([#16204](https://github.com/Lightning-AI/lightning/pull/16204)) +- Removed the `{training,validation,test}_epoch_end` hooks, which retained all step outputs in memory. Implement the corresponding `on_*_epoch_end` hooks instead and cache any step outputs you need as attributes on the module ([#16520](https://github.com/Lightning-AI/lightning/pull/16520)) + - Removed support for the experimental `PL_FAULT_TOLERANT_TRAINING` environment flag ([#16516](https://github.com/Lightning-AI/lightning/pull/16516), [#16533](https://github.com/Lightning-AI/lightning/pull/16533)) - Removed the deprecated `LightningCLI` arguments ([#16380](https://github.com/Lightning-AI/lightning/pull/16380)) diff --git a/src/lightning/pytorch/callbacks/callback.py b/src/lightning/pytorch/callbacks/callback.py index 432764e036508..8e9dcfce0a586 100644 --- a/src/lightning/pytorch/callbacks/callback.py +++ b/src/lightning/pytorch/callbacks/callback.py @@ -95,10 +95,29 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called when the train epoch ends. 
- To access all batch outputs at the end of the epoch, either: + To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the + :class:`pytorch_lightning.LightningModule` and access them in this hook: - 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR - 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook. + .. code-block:: python + + class MyLightningModule(L.LightningModule): + def __init__(self): + super().__init__() + self.training_step_outputs = [] + + def training_step(self, batch, batch_idx): + loss = ... + self.training_step_outputs.append(loss) + return loss + + + class MyCallback(L.Callback): + def on_train_epoch_end(self, trainer, pl_module): + # do something with all training_step outputs, for example: + epoch_mean = torch.stack(pl_module.training_step_outputs).mean() + pl_module.log("training_epoch_mean", epoch_mean) + # free up the memory + pl_module.training_step_outputs.clear() """ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py index 1b2aacbcdee54..501c270fe584d 100644 --- a/src/lightning/pytorch/core/hooks.py +++ b/src/lightning/pytorch/core/hooks.py @@ -169,10 +169,27 @@ def on_train_epoch_start(self) -> None: def on_train_epoch_end(self) -> None: """Called in the training loop at the very end of the epoch. - To access all batch outputs at the end of the epoch, either: + To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the + :class:`pytorch_lightning.LightningModule` and access them in this hook: - 1. Implement `training_epoch_end` in the LightningModule OR - 2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook + .. code-block:: python + + class MyLightningModule(L.LightningModule): + def __init__(self): + super().__init__() + self.training_step_outputs = [] + + def training_step(self, batch, batch_idx): + loss = ... 
+ self.training_step_outputs.append(loss) + return loss + + def on_train_epoch_end(self): + # do something with all training_step outputs, for example: + epoch_mean = torch.stack(self.training_step_outputs).mean() + self.log("training_epoch_mean", epoch_mean) + # free up the memory + self.training_step_outputs.clear() """ def on_validation_epoch_start(self) -> None: diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 652bf60c4cd5a..7a363463c2511 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -48,13 +48,7 @@ from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -from lightning.pytorch.utilities.types import ( - _METRIC, - EPOCH_OUTPUT, - LRSchedulerPLType, - LRSchedulerTypeUnion, - STEP_OUTPUT, -) +from lightning.pytorch.utilities.types import _METRIC, LRSchedulerPLType, LRSchedulerTypeUnion, STEP_OUTPUT warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -753,10 +747,10 @@ def training_step(self, batch, batch_idx): return {"pred": out} - def training_step_end(self, training_step_outputs): - gpu_0_pred = training_step_outputs[0]["pred"] - gpu_1_pred = training_step_outputs[1]["pred"] - gpu_n_pred = training_step_outputs[n]["pred"] + def training_step_end(self, training_step_output): + gpu_0_pred = training_step_output[0]["pred"] + gpu_1_pred = training_step_output[1]["pred"] + gpu_n_pred = training_step_output[n]["pred"] # this softmax now uses the full batch loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred]) @@ -766,51 +760,11 @@ def training_step_end(self, training_step_outputs): See the :ref:`Multi GPU Training ` guide for more details. """ - def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - """Called at the end of the training epoch with the outputs of all training steps. Use this in case you - need to do something with all the outputs returned by :meth:`training_step`. - - .. code-block:: python - - # the pseudocode for these calls - train_outs = [] - for train_batch in train_data: - out = training_step(train_batch) - train_outs.append(out) - training_epoch_end(train_outs) - - Args: - outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists - have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed. - - Return: - None - - Note: - If this method is not overridden, this won't be called. - - .. code-block:: python - - def training_epoch_end(self, training_step_outputs): - # do something with all training_step outputs - for out in training_step_outputs: - ... - """ - def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy. - .. code-block:: python - - # the pseudocode for these calls - val_outs = [] - for val_batch in val_data: - out = validation_step(val_batch) - val_outs.append(out) - validation_epoch_end(val_outs) - Args: batch: The output of your :class:`~torch.utils.data.DataLoader`. batch_idx: The index of this batch. @@ -824,13 +778,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: .. 
code-block:: python # pseudocode of order - val_outs = [] for val_batch in val_data: out = validation_step(val_batch) if defined("validation_step_end"): out = validation_step_end(out) - val_outs.append(out) - val_outs = validation_epoch_end(val_outs) .. code-block:: python @@ -939,65 +890,12 @@ def validation_step_end(self, val_step_outputs): See the :ref:`Multi GPU Training ` guide for more details. """ - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - """Called at the end of the validation epoch with the outputs of all validation steps. - - .. code-block:: python - - # the pseudocode for these calls - val_outs = [] - for val_batch in val_data: - out = validation_step(val_batch) - val_outs.append(out) - validation_epoch_end(val_outs) - - Args: - outputs: List of outputs you defined in :meth:`validation_step`, or if there - are multiple dataloaders, a list containing a list of outputs for each dataloader. - - Return: - None - - Note: - If you didn't define a :meth:`validation_step`, this won't be called. - - Examples: - With a single dataloader: - - .. code-block:: python - - def validation_epoch_end(self, val_step_outputs): - for out in val_step_outputs: - ... - - With multiple dataloaders, `outputs` will be a list of lists. The outer list contains - one entry per dataloader, while the inner list contains the individual outputs of - each validation step for that dataloader. - - .. code-block:: python - - def validation_epoch_end(self, outputs): - for dataloader_output_result in outputs: - dataloader_outs = dataloader_output_result.dataloader_i_outputs - - self.log("final_metric", final_value) - """ - def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy. - .. code-block:: python - - # the pseudocode for these calls - test_outs = [] - for test_batch in test_data: - out = test_step(test_batch) - test_outs.append(out) - test_epoch_end(test_outs) - Args: batch: The output of your :class:`~torch.utils.data.DataLoader`. batch_idx: The index of this batch. @@ -1117,56 +1015,6 @@ def test_step_end(self, output_results): See the :ref:`Multi GPU Training ` guide for more details. """ - def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - """Called at the end of a test epoch with the output of all test steps. - - .. code-block:: python - - # the pseudocode for these calls - test_outs = [] - for test_batch in test_data: - out = test_step(test_batch) - test_outs.append(out) - test_epoch_end(test_outs) - - Args: - outputs: List of outputs you defined in :meth:`test_step_end`, or if there - are multiple dataloaders, a list containing a list of outputs for each dataloader - - Return: - None - - Note: - If you didn't define a :meth:`test_step`, this won't be called. - - Examples: - With a single dataloader: - - .. code-block:: python - - def test_epoch_end(self, outputs): - # do something with the outputs of all test batches - all_test_preds = test_step_outputs.predictions - - some_result = calc_all_results(all_test_preds) - self.log(some_result) - - With multiple dataloaders, `outputs` will be a list of lists. The outer list contains - one entry per dataloader, while the inner list contains the individual outputs of - each test step for that dataloader. - - .. 
code-block:: python - - def test_epoch_end(self, outputs): - final_value = 0 - for dataloader_outputs in outputs: - for test_step_out in dataloader_outputs: - # do something - final_value += test_step_out - - self.log("final_metric", final_value) - """ - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: """Step function called during :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. By default, it calls :meth:`~lightning.pytorch.core.module.LightningModule.forward`. Override to add any processing logic. diff --git a/src/lightning/pytorch/demos/boring_classes.py b/src/lightning/pytorch/demos/boring_classes.py index 87ba74d3d00b6..2f960c039cb86 100644 --- a/src/lightning/pytorch/demos/boring_classes.py +++ b/src/lightning/pytorch/demos/boring_classes.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -23,7 +23,7 @@ from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER from lightning.pytorch import LightningDataModule, LightningModule from lightning.pytorch.core.optimizer import LightningOptimizer -from lightning.pytorch.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from lightning.pytorch.utilities.types import STEP_OUTPUT class RandomDictDataset(Dataset): @@ -89,14 +89,14 @@ class TestModel(BoringModel): def training_step(self, ...): ... # do your own thing - training_epoch_end = None # disable hook + training_step_end = None # disable hook or Example:: model = BoringModel() - model.training_epoch_end = None # disable hook + model.training_step_end = None # disable hook """ super().__init__() self.layer = torch.nn.Linear(32, 2) @@ -117,27 +117,15 @@ def step(self, batch: Tensor) -> Tensor: def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: return {"loss": self.step(batch)} - def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT: - return training_step_outputs - - def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - outputs = cast(List[Dict[str, Tensor]], outputs) - torch.stack([x["loss"] for x in outputs]).mean() + def training_step_end(self, training_step_output: STEP_OUTPUT) -> STEP_OUTPUT: + return training_step_output def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: return {"x": self.step(batch)} - def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - outputs = cast(List[Dict[str, Tensor]], outputs) - torch.stack([x["x"] for x in outputs]).mean() - def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: return {"y": self.step(batch)} - def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: - outputs = cast(List[Dict[str, Tensor]], outputs) - torch.stack([x["y"] for x in outputs]).mean() - def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]: optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) diff --git a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py index 75ed1d9338211..65b2065f47cbf 100644 --- a/src/lightning/pytorch/loops/dataloader/evaluation_loop.py +++ 
b/src/lightning/pytorch/loops/dataloader/evaluation_loop.py @@ -31,7 +31,6 @@ from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher from lightning.pytorch.utilities.rank_zero import rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature -from lightning.pytorch.utilities.types import EPOCH_OUTPUT if _RICH_AVAILABLE: from rich import get_console @@ -51,7 +50,6 @@ def __init__(self, verbose: bool = True) -> None: self.verbose = verbose self._results = _ResultCollection(training=False) - self._outputs: List[EPOCH_OUTPUT] = [] self._logged_outputs: List[_OUT_DICT] = [] self._max_batches: List[Union[int, float]] = [] self._has_run: bool = False @@ -113,7 +111,6 @@ def reset(self) -> None: """Resets the internal state of the loop.""" self._max_batches = self._get_max_batches() # bookkeeping - self._outputs = [] self._logged_outputs = [] if isinstance(self._max_batches, int): @@ -154,10 +151,7 @@ def batch_to_device(batch: Any) -> Any: kwargs = OrderedDict() if self.num_dataloaders > 1: kwargs["dataloader_idx"] = dataloader_idx - dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) - - # store batch level output per dataloader - self._outputs.append(dl_outputs) + self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) if not self.trainer.sanity_checking: # indicate the loop has run @@ -180,10 +174,7 @@ def on_run_end(self) -> List[_OUT_DICT]: """Runs the ``_on_evaluation_epoch_end`` hook.""" # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end` self.trainer._logger_connector.epoch_end_reached() - - # hook - self._evaluation_epoch_end(self._outputs) - self._outputs = [] # free memory + self.trainer._logger_connector._evaluation_epoch_end() # hook self._on_evaluation_epoch_end() @@ -216,7 +207,6 @@ def teardown(self) -> None: self._data_fetcher.teardown() self._data_fetcher = None self._results.cpu() - self.epoch_loop.teardown() def _get_max_batches(self) -> List[Union[int, float]]: """Returns the max number of batches for each dataloader.""" @@ -279,19 +269,6 @@ def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: self.trainer._call_callback_hooks(hook_name, *args, **kwargs) self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs) - def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None: - """Runs ``{validation/test}_epoch_end``""" - self.trainer._logger_connector._evaluation_epoch_end() - - # with a single dataloader don't pass a 2D list - output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = ( - outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs - ) - - # call the model epoch end - hook_name = "test_epoch_end" if self.trainer.testing else "validation_epoch_end" - self.trainer._call_lightning_module_hook(hook_name, output_or_outputs) - def _on_evaluation_epoch_end(self) -> None: """Runs ``on_{validation/test}_epoch_end`` hook.""" hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" diff --git a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py index 6a70c08c52eda..4b8eb66255bfa 100644 --- a/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py @@ -13,7 +13,6 @@ # limitations under the License. 
from collections import OrderedDict -from functools import lru_cache from typing import Any, Optional, Union from lightning.pytorch.loops.loop import _Loop @@ -21,8 +20,7 @@ from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import SIGTERMException from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher -from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning.pytorch.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from lightning.pytorch.utilities.types import STEP_OUTPUT class _EvaluationEpochLoop(_Loop): @@ -36,7 +34,6 @@ def __init__(self) -> None: super().__init__() self.batch_progress = BatchProgress() - self._outputs: EPOCH_OUTPUT = [] self._dl_max_batches: Union[int, float] = 0 self._data_fetcher: Optional[AbstractDataFetcher] = None self._dl_batch_idx = [0] @@ -46,9 +43,7 @@ def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" return self.batch_progress.current.completed >= self._dl_max_batches - def run( - self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict - ) -> EPOCH_OUTPUT: + def run(self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None: self.reset() self.on_run_start(data_fetcher, dl_max_batches, kwargs) while not self.done: @@ -58,13 +53,12 @@ def run( except StopIteration: break self._restarting = False - return self.on_run_end() + self.on_run_end() def reset(self) -> None: """Resets the loop's internal state.""" self._dl_max_batches = 0 self._data_fetcher = None - self._outputs = [] if not self.restarting: self.batch_progress.reset_on_run() @@ -154,22 +148,11 @@ def advance( self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx]) self._dl_batch_idx[dataloader_idx] += 1 - # track epoch level outputs - if self._should_track_batch_outputs_for_epoch_end() and output is not None: - self._outputs.append(output) - if not self.batch_progress.is_last_batch and self.trainer.received_sigterm: raise SIGTERMException - def on_run_end(self) -> EPOCH_OUTPUT: - """Returns the outputs of the whole run.""" - outputs, self._outputs = self._outputs, [] # free memory + def on_run_end(self) -> None: self._data_fetcher = None - return outputs - - def teardown(self) -> None: - # in case the model changes - self._should_track_batch_outputs_for_epoch_end.cache_clear() def _num_completed_batches_reached(self) -> bool: epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches @@ -253,13 +236,5 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde kwargs.move_to_end("batch", last=False) return kwargs - @lru_cache(1) - def _should_track_batch_outputs_for_epoch_end(self) -> bool: - """Whether the batch outputs should be stored for later usage.""" - model = self.trainer.lightning_module - if self.trainer.testing: - return is_overridden("test_epoch_end", model) - return is_overridden("validation_epoch_end", model) - def _reset_dl_batch_idx(self, num_dataloaders: int) -> None: self._dl_batch_idx = [0] * num_dataloaders diff --git a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py index 271e6e1326fcd..8573ebf3f379f 100644 --- a/src/lightning/pytorch/loops/epoch/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/epoch/training_epoch_loop.py @@ -13,7 +13,7 @@ # 
limitations under the License. import math from collections import OrderedDict -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union import torch @@ -26,12 +26,10 @@ from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher -from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature _BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] -_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] class _TrainingEpochLoop(loops._Loop): @@ -74,7 +72,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None self.val_loop = loops._EvaluationLoop(verbose=False) self._results = _ResultCollection(training=True) - self._outputs: _OUTPUTS_TYPE = [] self._warning_cache = WarningCache() self._batches_that_stepped: int = 0 @@ -128,7 +125,7 @@ def done(self) -> bool: return False - def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE: + def run(self, data_fetcher: AbstractDataFetcher) -> None: self.reset() self.on_run_start(data_fetcher) while not self.done: @@ -139,7 +136,6 @@ def run(self, data_fetcher: AbstractDataFetcher) -> _OUTPUTS_TYPE: except StopIteration: break self._restarting = False - return self.on_run_end() def reset(self) -> None: """Resets the internal state of the loop for a new run.""" @@ -164,8 +160,6 @@ def reset(self) -> None: # seen per epoch, this is useful for tracking when validation is run multiple times per epoch self.val_loop.epoch_loop.batch_progress.total.reset() - self._outputs = [] - def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None: _ = iter(data_fetcher) # creates the iterator inside the fetcher # add the previous `fetched` value to properly track `is_last_batch` with no prefetching @@ -240,13 +234,6 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None: self.batch_progress.increment_completed() - if batch_output and is_overridden("training_epoch_end", self.trainer.lightning_module): - # batch_output may be empty - # automatic: can be empty if all optimizers skip their batches - # manual: #9052 added support for raising `StopIteration` in the `training_step`. 
If that happens, - # then `advance` doesn't finish and an empty dict is returned - self._outputs.append(batch_output) - # ----------------------------------------- # SAVE METRICS TO LOGGERS AND PROGRESS_BAR # ----------------------------------------- @@ -276,10 +263,6 @@ def on_advance_end(self) -> None: if not self._is_training_done and self.trainer.received_sigterm: raise SIGTERMException - def on_run_end(self) -> _OUTPUTS_TYPE: - outputs, self._outputs = self._outputs, [] - return outputs - def teardown(self) -> None: self._results.cpu() self.val_loop.teardown() diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 3c18daab56f89..0826b0f81a003 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -17,14 +17,12 @@ import lightning.pytorch as pl from lightning.pytorch.loops import _Loop from lightning.pytorch.loops.epoch import _TrainingEpochLoop -from lightning.pytorch.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE from lightning.pytorch.loops.progress import Progress from lightning.pytorch.loops.utilities import _is_max_limit_reached, _set_sampler_epoch from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection from lightning.pytorch.trainer.supporters import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException from lightning.pytorch.utilities.fetching import AbstractDataFetcher, DataFetcher, DataLoaderIterDataFetcher -from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature @@ -76,7 +74,6 @@ def __init__( self.epoch_progress = Progress() self._is_fresh_start_epoch: bool = True - self._outputs: _EPOCH_OUTPUTS_TYPE = [] self._data_fetcher: Optional[AbstractDataFetcher] = None @property @@ -240,9 +237,6 @@ def on_advance_start(self) -> None: self.trainer.reset_train_dataloader(model) self._is_fresh_start_epoch = False - # reset outputs here instead of in `reset` as they are not accumulated between epochs - self._outputs = [] - if self.trainer.train_dataloader is not None: assert isinstance(self.trainer.train_dataloader, CombinedLoader) _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed) @@ -273,24 +267,12 @@ def batch_to_device(batch: Any) -> Any: assert self._data_fetcher is not None self._data_fetcher.setup(dataloader, batch_to_device=batch_to_device) with self.trainer.profiler.profile("run_training_epoch"): - self._outputs = self.epoch_loop.run(self._data_fetcher) + self.epoch_loop.run(self._data_fetcher) def on_advance_end(self) -> None: # inform logger the batch loop has finished self.trainer._logger_connector.epoch_end_reached() - # get the model and call model.training_epoch_end - model = self.trainer.lightning_module - if is_overridden("training_epoch_end", model) and self._outputs: - return_value = self.trainer._call_lightning_module_hook("training_epoch_end", self._outputs) - if return_value is not None: - raise MisconfigurationException( - "`training_epoch_end` expects a return of None. " - "HINT: remove the return statement in `training_epoch_end`." 
- ) - # free memory - self._outputs = [] - self.epoch_progress.increment_processed() # call train epoch end hooks diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index 2dd2491ef887b..dbc546e5cce4c 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -48,7 +48,7 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) elif training_step_output is not None: raise MisconfigurationException( "In manual optimization, `training_step` must either return a Tensor, " - "a dict with extras to pass to `training_epoch_end` or have no return." + "a dict with extras to pass to `training_step_end` or have no return." ) if "loss" in extra: diff --git a/src/lightning/pytorch/trainer/configuration_validator.py b/src/lightning/pytorch/trainer/configuration_validator.py index d7c2dca76aeb5..f0565132ce7c3 100644 --- a/src/lightning/pytorch/trainer/configuration_validator.py +++ b/src/lightning/pytorch/trainer/configuration_validator.py @@ -50,19 +50,13 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: - # ----------------------------------- - # verify model has a training step - # ----------------------------------- + # verify minimum training requirements has_training_step = is_overridden("training_step", model) if not has_training_step: raise MisconfigurationException( "No `training_step()` method defined. Lightning `Trainer` expects as minimum a" " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) - - # ----------------------------------- - # verify model has optimizer - # ----------------------------------- has_optimizers = is_overridden("configure_optimizers", model) if not has_optimizers: raise MisconfigurationException( @@ -70,11 +64,11 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh " `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined." ) + # verify gradient accumulation setup overridden_optimizer_step = is_overridden("optimizer_step", model) overridden_optimizer_zero_grad = is_overridden("optimizer_zero_grad", model) automatic_optimization = model.automatic_optimization going_to_accumulate_grad_batches = trainer.accumulation_scheduler.going_to_accumulate_grad_batches() - has_overridden_optimization_functions = overridden_optimizer_step or overridden_optimizer_zero_grad if has_overridden_optimization_functions and going_to_accumulate_grad_batches and automatic_optimization: rank_zero_warn( @@ -83,13 +77,9 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh " (rather, they are called on every optimization step)." ) - # ----------------------------------- - # verify model for val loop - # ----------------------------------- - + # verify minimum validation requirements has_val_loader = trainer._data_connector._val_dataloader_source.is_defined() has_val_step = is_overridden("validation_step", model) - if has_val_loader and not has_val_step: rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. 
Skipping val loop.") if has_val_step and not has_val_loader: @@ -98,11 +88,25 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh category=PossibleUserWarning, ) + # check legacy hooks are not present + if callable(getattr(model, "training_epoch_end", None)): + raise NotImplementedError( + f"Support for `training_epoch_end` has been removed in v2.0.0. `{type(model).__name__}` implements this" + " method. You can use the `on_train_epoch_end` hook instead. To access outputs, save them in memory as" + " instance attributes." + " You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520." + ) + if callable(getattr(model, "validation_epoch_end", None)): + raise NotImplementedError( + f"Support for `validation_epoch_end` has been removed in v2.0.0. `{type(model).__name__}` implements this" + " method. You can use the `on_validation_epoch_end` hook instead. To access outputs, save them in memory as" + " instance attributes." + " You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520." + ) + def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None: step_name = "validation_step" if stage == "val" else f"{stage}_step" - trainer_method = "validate" if stage == "val" else stage - has_step = is_overridden(step_name, model) # predict_step is not required to be overridden @@ -112,12 +116,21 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> elif not has_step and not is_overridden("forward", model): raise MisconfigurationException("`Trainer.predict` requires `forward` method to run.") else: - # ----------------------------------- - # verify model has an eval_step - # ----------------------------------- + # verify minimum evaluation requirements if not has_step: + trainer_method = "validate" if stage == "val" else stage raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.") + # check legacy hooks are not present + epoch_end_name = "validation_epoch_end" if stage == "val" else "test_epoch_end" + if callable(getattr(model, epoch_end_name, None)): + raise NotImplementedError( + f"Support for `{epoch_end_name}` has been removed in v2.0.0. `{type(model).__name__}` implements this" + f" method. You can use the `on_{epoch_end_name}` hook instead. To access outputs, save them in memory" + " as instance attributes." + " You can find migration examples in https://github.com/Lightning-AI/lightning/pull/16520." 
+ ) + def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None: """Raise Misconfiguration exception since these hooks are not supported in DP mode.""" diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index 05dc36dae7da5..bbb7d0020048c 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -143,15 +143,6 @@ class _LogOptions(TypedDict): "test_step_end": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True ), - "training_epoch_end": _LogOptions( - allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True - ), - "validation_epoch_end": _LogOptions( - allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True - ), - "test_epoch_end": _LogOptions( - allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True - ), "configure_optimizers": None, "train_dataloader": None, "val_dataloader": None, diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index f22bfed2dbd7b..5c3d26435027f 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -561,8 +561,7 @@ def validate( Returns: List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks - like :meth:`~lightning.pytorch.core.module.LightningModule.validation_step`, - :meth:`~lightning.pytorch.core.module.LightningModule.validation_epoch_end`, etc. + like :meth:`~lightning.pytorch.LightningModule.validation_step` etc. The length of the list corresponds to the number of validation dataloaders used. """ if model is None: @@ -654,8 +653,7 @@ def test( Returns: List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks - like :meth:`~lightning.pytorch.core.module.LightningModule.test_step`, - :meth:`~lightning.pytorch.core.module.LightningModule.test_epoch_end`, etc. + like :meth:`~lightning.pytorch.LightningModule.test_step` etc. The length of the list corresponds to the number of test dataloaders used. 
""" if model is None: diff --git a/src/lightning/pytorch/utilities/types.py b/src/lightning/pytorch/utilities/types.py index e42cca87e78a7..3f45998a62ced 100644 --- a/src/lightning/pytorch/utilities/types.py +++ b/src/lightning/pytorch/utilities/types.py @@ -31,7 +31,6 @@ _NUMBER = Union[int, float] _METRIC = Union[Metric, Tensor, _NUMBER] STEP_OUTPUT = Union[Tensor, Dict[str, Any]] -EPOCH_OUTPUT = List[STEP_OUTPUT] _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] TRAIN_DATALOADERS = Union[ diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 478c050e5de18..6ad3e197916f6 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -47,15 +47,6 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): return self.step(batch) - def training_epoch_end(self, outputs) -> None: - pass - - def validation_epoch_end(self, outputs) -> None: - pass - - def test_epoch_end(self, outputs) -> None: - pass - class IPUClassificationModel(ClassificationModel): def training_step(self, batch, batch_idx): diff --git a/tests/tests_pytorch/accelerators/test_tpu.py b/tests/tests_pytorch/accelerators/test_tpu.py index 11defef83ae50..ff1c360324f51 100644 --- a/tests/tests_pytorch/accelerators/test_tpu.py +++ b/tests/tests_pytorch/accelerators/test_tpu.py @@ -155,7 +155,6 @@ def on_train_end(self): model = ManualOptimizationModel() model_copy = deepcopy(model) model.training_step_end = None - model.training_epoch_end = None trainer = Trainer( max_epochs=1, diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index fb44d11da6070..9605bae1b8441 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -136,8 +136,6 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): return model = CustomModel() - model.validation_epoch_end = None - model.test_epoch_end = None # check the sanity dataloaders num_sanity_val_steps = 4 diff --git a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py index abd24d0c899e4..5fe13a259657e 100644 --- a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py +++ b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py @@ -41,9 +41,6 @@ def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: assert "x" in outputs - def training_epoch_end(self, outputs) -> None: - assert len(outputs) == self.trainer.num_training_batches - model = TestModel() trainer = Trainer( @@ -59,22 +56,3 @@ def training_epoch_end(self, outputs) -> None: assert any(isinstance(c, CB) for c in trainer.callbacks) trainer.fit(model) - - -def test_free_memory_on_eval_outputs(tmpdir): - class CB(Callback): - def on_train_epoch_end(self, trainer, pl_module): - assert not trainer._evaluation_loop._outputs - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - enable_model_summary=False, - ) - - trainer.fit(model) diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py index 
38a7caea56f39..391796685a70a 100644 --- a/tests/tests_pytorch/callbacks/test_lr_monitor.py +++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py @@ -248,7 +248,6 @@ def configure_optimizers(self): return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] model = CustomBoringModel() - model.training_epoch_end = None lr_monitor = LearningRateMonitor(logging_interval=logging_interval) log_every_n_steps = 2 @@ -306,7 +305,6 @@ def configure_optimizers(self): return [optimizer1, optimizer2] model = CustomBoringModel() - model.training_epoch_end = None lr_monitor = LearningRateMonitor(logging_interval=logging_interval) log_every_n_steps = 2 @@ -563,7 +561,6 @@ def finetune_function(self, pl_module, epoch: int, optimizer): enable_checkpointing=False, ) model = TestModel() - model.training_epoch_end = None trainer.fit(model) expected = [0.1, 0.1, 0.1, 0.1, 0.1] diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py index 40f418342ce45..26059f9d4c531 100644 --- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py @@ -96,7 +96,7 @@ def training_step(self, batch, batch_idx): self.log("my_loss", batch_idx * (1 + local_rank), on_epoch=True) return super().training_step(batch, batch_idx) - def training_epoch_end(self, outputs) -> None: + def on_train_epoch_end(self): local_rank = int(os.getenv("LOCAL_RANK")) if self.trainer.is_global_zero: self.log("my_loss_2", (1 + local_rank), on_epoch=True, rank_zero_only=True) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 18bad6e5a6d3a..9e82120f05d68 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -59,9 +59,8 @@ def training_step(self, batch, batch_idx): self.log("early_stop_on", out["loss"]) return out - def validation_epoch_end(self, outputs): - outs = torch.stack([x["x"] for x in outputs]).mean() - self.log("val_acc", outs) + def on_validation_epoch_end(self): + self.log("val_acc", torch.tensor(1.23)) def mock_training_epoch_loop(trainer): @@ -214,9 +213,8 @@ def validation_step(self, batch, batch_idx): self.log("val_log", log_value) return super().validation_step(batch, batch_idx) - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self.val_loop_count += 1 - super().validation_epoch_end(outputs) self.scores.append(self.trainer.logged_metrics[monitor]) def configure_optimizers(self): @@ -829,7 +827,7 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode): monitor += [5, 7, 8] if mode == "max" else [8, 7, 5] class CurrentModel(LogInTwoMethods): - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): val_loss = monitor[self.current_epoch] self.log("abc", val_loss) @@ -863,7 +861,6 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss) model = ExtendedBoringModel() - model.validation_epoch_end = None trainer_kwargs = { "max_epochs": 1, "limit_train_batches": 2, @@ -901,9 +898,6 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss) return {"val_loss": loss} - def validation_epoch_end(self, *_): - ... 
- def assert_trainer_init(trainer): assert trainer.global_step == 0 assert trainer.current_epoch == 0 diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index 8fb3b7f8cb3e4..aeb8c830cd881 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -41,7 +41,6 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss, on_epoch=True, prog_bar=True) model = ExtendedBoringModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 5b51a645fbce4..3e18e0fcecf23 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -163,10 +163,8 @@ def test_train_loop_only(tmpdir): model.validation_step = None model.validation_step_end = None - model.validation_epoch_end = None model.test_step = None model.test_step_end = None - model.test_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False) @@ -185,7 +183,6 @@ def test_train_val_loop_only(tmpdir): model.validation_step = None model.validation_step_end = None - model.validation_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, enable_model_summary=False) @@ -278,10 +275,8 @@ def train_dataloader(self): model.validation_step = None model.validation_step_end = None - model.validation_epoch_end = None model.test_step = None model.test_step_end = None - model.test_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=3, limit_train_batches=2, reload_dataloaders_every_n_epochs=2) trainer.fit(model, dm) diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 60e03631701a9..f94c8c31ac8ab 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -177,7 +177,6 @@ def configure_optimizers(self): return [optimizer_1, optimizer_2] model = TestModel() - model.training_epoch_end = None trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=0) trainer.fit(model) @@ -280,7 +279,6 @@ def configure_optimizers(self): return [optimizer_1, optimizer_2, optimizer_3] model = TestModel() - model.training_epoch_end = None trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=8) trainer.fit(model) diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index 2d45b92f2460d..0161df5e7ef7a 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -109,7 +109,6 @@ def configure_optimizers(self): model = TestModel() model.training_step_end = None - model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=1, max_epochs=1, enable_model_summary=False ) @@ -165,9 +164,6 @@ def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir): """Test overriding zero_grad works in automatic_optimization.""" class TestModel(BoringModel): - def training_epoch_end(self, outputs): - ... 
- def optimizer_zero_grad(self, epoch, batch_idx, optimizer): if batch_idx % 2 == 0: optimizer.zero_grad() @@ -191,9 +187,6 @@ def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir): """Test overriding step works in automatic_optimization.""" class TestModel(BoringModel): - def training_epoch_end(self, outputs): - ... - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, **_): assert isinstance(optimizer_closure, Closure) # zero_grad is called inside the closure diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index a346fa09c89cf..b5a4b588881c2 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -24,11 +24,9 @@ def __init__(self, weights=None): self.training_step_called = False self.training_step_end_called = False - self.training_epoch_end_called = False self.validation_step_called = False self.validation_step_end_called = False - self.validation_epoch_end_called = False self.assert_backward = True @@ -74,18 +72,6 @@ def validation_step_end(self, val_step_output): return val_step_output - def validation_epoch_end(self, outputs): - assert len(outputs) == self.trainer.num_val_batches[0] - - for i, out in enumerate(outputs): - assert out["log"]["log_acc1"] >= 12 + i - - self.validation_epoch_end_called = True - - result = outputs[-1] - result["val_epoch_end"] = torch.tensor(1233) - return result - # ----------------------------- # DATA # ----------------------------- diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 4b5ba093008d8..380058e752387 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -97,13 +97,11 @@ def training_step(self, batch, batch_idx): self.log("train_some_val", loss) return {"loss": loss} - def validation_epoch_end(self, outputs) -> None: - avg_val_loss = torch.stack([x["x"] for x in outputs]).mean() - self.log_dict({"early_stop_on": avg_val_loss, "val_loss": avg_val_loss**0.5}) + def on_validation_epoch_end(self): + self.log_dict({"early_stop_on": torch.tensor(1), "val_loss": torch.tensor(0.5)}) - def test_epoch_end(self, outputs) -> None: - avg_test_loss = torch.stack([x["y"] for x in outputs]).mean() - self.log("test_loss", avg_test_loss) + def on_test_epoch_end(self): + self.log("test_loss", torch.tensor(2)) class StoreHistoryLogger(logger_class): def __init__(self, *args, **kwargs) -> None: diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index e1f390e6de3a4..ba82adac67bdf 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -146,11 +146,11 @@ def log_metrics(self, metrics, step): super().log_metrics(metrics, step) class CustomModel(BoringModel): - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): self.logger.logged_step += 1 self.log_dict({"step": self.logger.logged_step, "train_acc": self.logger.logged_step / 10}) - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self.logger.logged_step += 1 self.log_dict({"step": self.logger.logged_step, "val_acc": self.logger.logged_step / 10}) diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index d215df8789352..e4e741eaa0a81 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -208,7 
+208,7 @@ def test_neptune_log_metrics_on_trained_model(self, neptune): """Verify that trained models do log data.""" # given class LoggingModel(BoringModel): - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self.log("some/key", 42) # and diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py index ff8f20935dec8..1274ce6d9a304 100644 --- a/tests/tests_pytorch/loggers/test_tensorboard.py +++ b/tests/tests_pytorch/loggers/test_tensorboard.py @@ -258,7 +258,6 @@ def training_step(self, *args): return super().training_step(*args) model = TestModel() - model.training_epoch_end = None logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False) trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py b/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py index e4cdc7a435dd0..d60f8088066b8 100644 --- a/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_optimizer_loop.py @@ -59,7 +59,3 @@ def training_step(self, batch, batch_idx): with pytest.raises(MisconfigurationException, match=match): trainer.fit(model) - - -class CustomException(Exception): - pass diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 837b5807fe1b3..293bde5c08d92 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -20,7 +20,6 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.utilities.model_helpers import is_overridden from tests_pytorch.helpers.runif import RunIf @@ -179,32 +178,3 @@ def validation_step(self, batch, batch_idx): enable_model_summary=False, ) trainer.fit(BoringLargeBatchModel()) - - -def test_evaluation_loop_doesnt_store_outputs_if_epoch_end_not_overridden(tmpdir): - did_assert = False - - class TestModel(BoringModel): - def on_test_batch_end(self, outputs, *_): - # check `test_step` returns something - assert outputs is not None - - model = TestModel() - model.test_epoch_end = None - assert not is_overridden("test_epoch_end", model) - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=3) - loop = trainer.test_loop.epoch_loop - original_advance = loop.advance - - def assert_on_advance_end(*args, **kwargs): - original_advance(*args, **kwargs) - # should be empty - assert not loop._outputs - # sanity check - nonlocal did_assert - did_assert = True - - loop.advance = assert_on_advance_end - trainer.test(model) - assert did_assert diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index c91c3d159e8e6..72d533a4b1d7f 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -45,7 +45,6 @@ def backward(self, loss): model = TestModel() model.validation_step_end = None - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -60,7 +59,6 @@ def backward(self, loss): # make sure correct steps were called assert model.validation_step_called assert not model.validation_step_end_called - assert not model.validation_epoch_end_called # simulate training manually trainer.state.stage = RunningStage.TRAINING @@ -104,7 +102,6 @@ def backward(self, loss): return LightningModule.backward(self, loss) model = 
TestModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -119,7 +116,6 @@ def backward(self, loss): # make sure correct steps were called assert model.validation_step_called assert model.validation_step_end_called - assert not model.validation_epoch_end_called trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected @@ -155,16 +151,6 @@ def validation_step(self, batch, batch_idx): self.out_b = out return out - def validation_epoch_end(self, outputs): - self.validation_epoch_end_called = True - assert len(outputs) == 2 - - out_a = outputs[0] - out_b = outputs[1] - - assert out_a == self.out_a - assert out_b == self.out_b - def backward(self, loss): return LightningModule.backward(self, loss) @@ -185,7 +171,6 @@ def backward(self, loss): # make sure correct steps were called assert model.validation_step_called assert not model.validation_step_end_called - assert model.validation_epoch_end_called def test__validation_step__step_end__epoch_end__flow(tmpdir): @@ -214,16 +199,6 @@ def validation_step_end(self, out): assert self.last_out == out return out - def validation_epoch_end(self, outputs): - self.validation_epoch_end_called = True - assert len(outputs) == 2 - - out_a = outputs[0] - out_b = outputs[1] - - assert out_a == self.out_a - assert out_b == self.out_b - def backward(self, loss): return LightningModule.backward(self, loss) @@ -243,4 +218,3 @@ def backward(self, loss): # make sure correct steps were called assert model.validation_step_called assert model.validation_step_end_called - assert model.validation_epoch_end_called diff --git a/tests/tests_pytorch/loops/test_flow_warnings.py b/tests/tests_pytorch/loops/test_flow_warnings.py index 33aa66d511663..b1c56f52de08e 100644 --- a/tests/tests_pytorch/loops/test_flow_warnings.py +++ b/tests/tests_pytorch/loops/test_flow_warnings.py @@ -27,7 +27,6 @@ def test_no_depre_without_epoch_end(tmpdir): """Tests that only training_step can be used.""" model = TestModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index b7c889091abe0..a1a82a5345c55 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -216,7 +216,6 @@ def val_dataloader(self): return [super(ValidationModel, self).val_dataloader() for _ in range(n_dataloaders)] model = ValidationModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -279,7 +278,6 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) model = TestModel() - model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -446,7 +444,6 @@ def train_dataloader(self): return DataLoader(RandomDataset(32, n_batches)) model = TestModel() - model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 5d94cd7f65df0..2e029ed75875d 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -42,11 +42,6 @@ def on_train_batch_end(self, outputs, batch, batch_idx): HookedModel._check_output(outputs) super().on_train_batch_end(outputs, batch, batch_idx) - def training_epoch_end(self, outputs): - assert len(outputs) == 2 - [HookedModel._check_output(output) for output in outputs] - 
super().training_epoch_end(outputs) - model = HookedModel() # fit model diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py index e4f877371b454..1b6f0acc59fb0 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py @@ -49,7 +49,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert not model.training_step_end_called - assert not model.training_epoch_end_called def test__training_step__tr_step_end__flow_dict(tmpdir): @@ -88,7 +87,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert model.training_step_end_called - assert not model.training_epoch_end_called def test__training_step__epoch_end__flow_dict(tmpdir): @@ -103,19 +101,6 @@ def training_step(self, batch, batch_idx): out = {"loss": acc, "random_things": [1, "a", torch.tensor(2)], "batch_idx": batch_idx} return out - def training_epoch_end(self, outputs): - self.training_epoch_end_called = True - - # verify we saw the current num of batches - assert len(outputs) == 2 - assert len({id(output) for output in outputs}) == 2 - assert [output["batch_idx"] for output in outputs] == [0, 1] - - for b in outputs: - assert isinstance(b, dict) - assert self.count_num_graphs(b) == 0 - assert {"random_things", "loss", "batch_idx"} == set(b.keys()) - def backward(self, loss): return LightningModule.backward(self, loss) @@ -135,7 +120,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert not model.training_step_end_called - assert model.training_epoch_end_called def test__training_step__step_end__epoch_end__flow_dict(tmpdir): @@ -156,19 +140,6 @@ def training_step_end(self, tr_step_output): self.training_step_end_called = True return tr_step_output - def training_epoch_end(self, outputs): - self.training_epoch_end_called = True - - # verify we saw the current num of batches - assert len(outputs) == 2 - assert len({id(output) for output in outputs}) == 2 - assert [output["batch_idx"] for output in outputs] == [0, 1] - - for b in outputs: - assert isinstance(b, dict) - assert self.count_num_graphs(b) == 0 - assert {"random_things", "loss", "batch_idx"} == set(b.keys()) - def backward(self, loss): return LightningModule.backward(self, loss) @@ -188,4 +159,3 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert model.training_step_end_called - assert model.training_epoch_end_called diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index 31e28a2e06669..54976f3d877f9 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -54,7 +54,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert not model.training_step_end_called - assert not model.training_epoch_end_called def test__training_step__tr_step_end__flow_scalar(tmpdir): @@ -93,7 +92,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert model.training_step_end_called - assert not model.training_epoch_end_called def test__training_step__epoch_end__flow_scalar(tmpdir): @@ -107,18 +105,6 @@ def training_step(self, batch, batch_idx): self.training_step_called = 
True return acc - def training_epoch_end(self, outputs): - self.training_epoch_end_called = True - - # verify we saw the current num of batches - assert len(outputs) == 2 - - for b in outputs: - # time = 1 - assert len(b) == 1 - assert "loss" in b - assert isinstance(b, dict) - def backward(self, loss): return LightningModule.backward(self, loss) @@ -138,7 +124,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert not model.training_step_end_called - assert model.training_epoch_end_called # assert epoch end metrics were added assert len(trainer.callback_metrics) == 0 @@ -159,7 +144,7 @@ def backward(self, loss): def test__training_step__step_end__epoch_end__flow_scalar(tmpdir): - """Checks train_step + training_step_end + training_epoch_end (all with scalar return from train_step).""" + """Checks train_step + training_step_end (all with scalar return from train_step).""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -175,18 +160,6 @@ def training_step_end(self, tr_step_output): self.training_step_end_called = True return tr_step_output - def training_epoch_end(self, outputs): - self.training_epoch_end_called = True - - # verify we saw the current num of batches - assert len(outputs) == 2 - - for b in outputs: - # time = 1 - assert len(b) == 1 - assert "loss" in b - assert isinstance(b, dict) - def backward(self, loss): return LightningModule.backward(self, loss) @@ -206,7 +179,6 @@ def backward(self, loss): # make sure correct steps were called assert model.training_step_called assert model.training_step_end_called - assert model.training_epoch_end_called # assert epoch end metrics were added assert len(trainer.callback_metrics) == 0 @@ -236,15 +208,9 @@ def training_step(self, batch): loss = self.step(batch[0]) self.log("a", loss, on_step=True, on_epoch=True) - def training_epoch_end(self, outputs) -> None: - assert len(outputs) == 0, outputs - def validation_step(self, batch, batch_idx): self.validation_step_called = True - def validation_epoch_end(self, outputs): - assert len(outputs) == 0, outputs - model = TestModel() trainer_args = dict(default_root_dir=tmpdir, fast_dev_run=2) trainer = Trainer(**trainer_args) diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 47dce9e969b37..eea57aac96333 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -67,7 +67,7 @@ def on_before_zero_grad(self, optimizer): assert 0 == model.on_before_zero_grad_called -def test_training_epoch_end_metrics_collection(tmpdir): +def test_on_train_epoch_end_metrics_collection(tmpdir): """Test that progress bar metrics also get collected at the end of an epoch.""" num_epochs = 3 @@ -77,7 +77,7 @@ def training_step(self, *args, **kwargs): self.log_dict({"step_metric": torch.tensor(-1), "shared_metric": 100}, logger=False, prog_bar=True) return output - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): epoch = self.current_epoch # both scalar tensors and Python numbers are accepted self.log_dict( @@ -99,40 +99,6 @@ def training_epoch_end(self, outputs): assert metrics[f"epoch_metric_{i}"] == i -def test_training_epoch_end_metrics_collection_on_override(tmpdir): - """Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch.""" - - class OverriddenModel(BoringModel): - def __init__(self): - super().__init__() - self.len_outputs = 0 - - def 
on_train_epoch_start(self): - self.num_train_batches = 0 - - def training_epoch_end(self, outputs): - self.len_outputs = len(outputs) - - def on_train_batch_end(self, outputs, batch, batch_idx): - self.num_train_batches += 1 - - class NotOverriddenModel(BoringModel): - def on_train_epoch_start(self): - self.num_train_batches = 0 - - def on_train_batch_end(self, outputs, batch, batch_idx): - self.num_train_batches += 1 - - overridden_model = OverriddenModel() - not_overridden_model = NotOverriddenModel() - not_overridden_model.training_epoch_end = None - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, overfit_batches=2) - - trainer.fit(overridden_model) - assert overridden_model.len_outputs == overridden_model.num_train_batches - - @pytest.mark.parametrize( "accelerator,expected_device_str", [ @@ -214,7 +180,6 @@ def train_dataloader(self): model = TestModel() model.validation_step = None - model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=2, @@ -286,14 +251,6 @@ def call(hook, fn, *args, **kwargs): update_wrapper(partial_h, attr) setattr(self, h, partial_h) - def validation_epoch_end(self, *args, **kwargs): - # `BoringModel` does not have a return for `validation_step_end` so this would fail - pass - - def test_epoch_end(self, *args, **kwargs): - # `BoringModel` does not have a return for `test_step_end` so this would fail - pass - def _train_batch(self, *args, **kwargs): if self.automatic_optimization: return self._auto_train_batch(*args, **kwargs) @@ -391,12 +348,10 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k @staticmethod def _eval_epoch(fn, trainer, model, batches, key, device=torch.device("cpu")): - outputs = {key: ANY} return [ dict(name=f"Callback.on_{fn}_epoch_start", args=(trainer, model)), dict(name=f"on_{fn}_epoch_start"), *HookedModel._eval_batch(fn, trainer, model, batches, key, device=device), - dict(name=f"{fn}_epoch_end", args=([outputs] * batches,)), dict(name=f"Callback.on_{fn}_epoch_end", args=(trainer, model)), dict(name=f"on_{fn}_epoch_end"), ] @@ -546,7 +501,6 @@ def training_step(self, batch, batch_idx): dict(name="on_validation_end"), dict(name="train", args=(True,)), dict(name="on_validation_model_train"), - dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" # `ModelCheckpoint.save_checkpoint` is called here @@ -625,7 +579,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir): dict(name="Callback.on_train_epoch_start", args=(trainer, model)), dict(name="on_train_epoch_start"), *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0), - dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" # `ModelCheckpoint.save_checkpoint` is called here @@ -705,7 +658,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir): dict(name="Callback.on_train_epoch_start", args=(trainer, model)), dict(name="on_train_epoch_start"), *model._train_batch(trainer, model, steps_after_reload, current_batch=1), - dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), dict(name="on_train_epoch_end"), # 
before ModelCheckpoint because it's a "monitoring callback" # `ModelCheckpoint.save_checkpoint` is called here @@ -791,7 +743,6 @@ def test_trainer_model_hook_system_predict(tmpdir): dict(name="Callback.on_predict_epoch_start", args=(trainer, model)), dict(name="on_predict_epoch_start"), *model._predict_batch(trainer, model, batches), - # TODO: `predict_epoch_end` dict(name="Callback.on_predict_epoch_end", args=(trainer, model, [[ANY] * batches])), dict(name="on_predict_epoch_end", args=([[ANY] * batches],)), dict(name="Callback.on_predict_end", args=(trainer, model)), @@ -837,7 +788,6 @@ def predict_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] model = CustomBoringModel() - model.test_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=5) diff --git a/tests/tests_pytorch/plugins/test_double_plugin.py b/tests/tests_pytorch/plugins/test_double_plugin.py index f7daa6da4e7b1..9c93f09cad221 100644 --- a/tests/tests_pytorch/plugins/test_double_plugin.py +++ b/tests/tests_pytorch/plugins/test_double_plugin.py @@ -45,9 +45,8 @@ def training_step(self, batch, batch_idx): assert float_data.dtype == torch.float64 return super().training_step(float_data, batch_idx) - def training_epoch_end(self, outputs) -> None: + def on_train_epoch_end(self): assert torch.tensor([0.0]).dtype == torch.float32 - return super().training_epoch_end(outputs) def validation_step(self, batch, batch_idx): assert batch.dtype == torch.float64 diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index a3174a6995656..15d9c06174d09 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -644,7 +644,6 @@ def test_deepspeed_multigpu_stage_3(tmpdir): def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config): """Test to ensure ZeRO Stage 3 works with a parallel model.""" model = ModelParallelBoringModelManualOptim() - model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, strategy=DeepSpeedStrategy(stage=3), diff --git a/tests/tests_pytorch/strategies/test_dp.py b/tests/tests_pytorch/strategies/test_dp.py index f078bca00e7c2..c1216a2613ff9 100644 --- a/tests/tests_pytorch/strategies/test_dp.py +++ b/tests/tests_pytorch/strategies/test_dp.py @@ -91,6 +91,12 @@ def test_multi_gpu_model_dp(tmpdir): class ReductionTestModel(BoringModel): + def __init__(self): + super().__init__() + self.train_outputs = [] + self.val_outputs = [] + self.test_outputs = [] + def train_dataloader(self): return DataLoader(RandomDataset(32, 64), batch_size=2) @@ -123,17 +129,32 @@ def test_step(self, batch, batch_idx): self.add_outputs(output, batch.device) return output - def training_epoch_end(self, outputs): - assert outputs[0]["loss"].shape == torch.Size([]) - self._assert_extra_outputs(outputs) + def training_step_end(self, training_step_output): + # the strategy does this automatically, but since we want to store these in memory, we need to manually do it + # so that we can append the reduced value and not the per-rank value + training_step_output["loss"] = self.trainer.strategy.reduce(training_step_output["loss"]) + self.train_outputs.append(training_step_output) + # return this or the DP strategy will reduce again + return training_step_output + + def validation_step_end(self, validation_step_output): + self.val_outputs.append(validation_step_output) + # returning a value is not 
necessary because there's no modification + + def test_step_end(self, test_step_output): + self.test_outputs.append(test_step_output) + + def on_train_epoch_end(self): + assert self.train_outputs[0]["loss"].shape == torch.Size([]) + self._assert_extra_outputs(self.train_outputs) - def validation_epoch_end(self, outputs): - assert outputs[0]["x"].shape == torch.Size([2]) - self._assert_extra_outputs(outputs) + def on_validation_epoch_end(self): + assert self.val_outputs[0]["x"].shape == torch.Size([2]) + self._assert_extra_outputs(self.val_outputs) - def test_epoch_end(self, outputs): - assert outputs[0]["y"].shape == torch.Size([2]) - self._assert_extra_outputs(outputs) + def on_test_epoch_end(self): + assert self.test_outputs[0]["y"].shape == torch.Size([2]) + self._assert_extra_outputs(self.test_outputs) def _assert_extra_outputs(self, outputs): out = outputs[0]["reduce_int"] @@ -149,9 +170,6 @@ def _assert_extra_outputs(self, outputs): def test_dp_training_step_dict(tmpdir): """This test verifies that dp properly reduces dictionaries.""" model = ReductionTestModel() - model.training_step_end = None - model.validation_step_end = None - model.test_step_end = None trainer = pl.Trainer( default_root_dir=tmpdir, diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index eef6c8aeeb57d..e183a9252a5d4 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -90,7 +90,6 @@ def test_dataloader(self): return [self.create_dataset()] * self._numbers_test_dataloaders model = TestModel(2, mode) - model.test_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py index 52bd4a3acff00..16ceb6222b74b 100644 --- a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
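The `ReductionTestModel` changes above spell out the general DP recipe now that `training_epoch_end` is gone: reduce in `training_step_end`, cache on the module, and aggregate in `on_train_epoch_end`. A minimal sketch of that pattern, assuming a DP-style strategy where `training_step_end` still runs per step; the class name, metric name, and Trainer arguments below are illustrative, not part of this PR:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class CachingDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        # cache step outputs manually; the removed hook no longer retains them
        self.train_outputs = []

    def training_step_end(self, training_step_output):
        # reduce across the DP replicas *before* caching, so the list holds one
        # scalar per step rather than a per-rank tensor
        training_step_output["loss"] = self.trainer.strategy.reduce(training_step_output["loss"])
        self.train_outputs.append(training_step_output)
        # return the reduced output, or the DP strategy will reduce it again
        return training_step_output

    def on_train_epoch_end(self):
        epoch_mean = torch.stack([out["loss"] for out in self.train_outputs]).mean()
        self.log("train_loss_epoch", epoch_mean)
        self.train_outputs.clear()  # free the memory


# illustrative usage: "dp" needs at least two GPUs
trainer = Trainer(accelerator="gpu", devices=2, strategy="dp", fast_dev_run=True)
trainer.fit(CachingDPModel())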
+import pytest import torch from torch.utils.data import Dataset @@ -42,41 +43,8 @@ def __len__(self): return self.len -def test_multiple_eval_dataloaders_tuple(tmpdir): - class TestModel(BoringModel): - def validation_step(self, batch, batch_idx, dataloader_idx): - if dataloader_idx == 0: - assert batch.sum() == 0 - elif dataloader_idx == 1: - assert batch.sum() == 11 - else: - raise Exception("should only have two dataloaders") - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - - def val_dataloader(self): - dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11) - dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11) - return [dl1, dl2] - - model = TestModel() - model.validation_epoch_end = None - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - log_every_n_steps=1, - enable_model_summary=False, - ) - - trainer.fit(model) - - -def test_multiple_eval_dataloaders_list(tmpdir): +@pytest.mark.parametrize("seq_type", (tuple, list)) +def test_multiple_eval_dataloaders_seq(tmpdir, seq_type): class TestModel(BoringModel): def validation_step(self, batch, batch_idx, dataloader_idx): if dataloader_idx == 0: @@ -89,10 +57,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx): def val_dataloader(self): dl1 = torch.utils.data.DataLoader(RandomDatasetA(32, 64), batch_size=11) dl2 = torch.utils.data.DataLoader(RandomDatasetB(32, 64), batch_size=11) - return dl1, dl2 + return seq_type((dl1, dl2)) model = TestModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, diff --git a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py index fccb4d65e3bd7..8308a59bb312f 100644 --- a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py +++ b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py @@ -40,9 +40,9 @@ class FastDevRunModel(BoringModel): def __init__(self): super().__init__() self.training_step_call_count = 0 - self.training_epoch_end_call_count = 0 + self.on_train_epoch_end_call_count = 0 self.validation_step_call_count = 0 - self.validation_epoch_end_call_count = 0 + self.on_validation_epoch_end_call_count = 0 self.test_step_call_count = 0 def training_step(self, batch, batch_idx): @@ -51,17 +51,15 @@ def training_step(self, batch, batch_idx): self.training_step_call_count += 1 return super().training_step(batch, batch_idx) - def training_epoch_end(self, outputs): - self.training_epoch_end_call_count += 1 - super().training_epoch_end(outputs) + def on_train_epoch_end(self): + self.on_train_epoch_end_call_count += 1 def validation_step(self, batch, batch_idx): self.validation_step_call_count += 1 return super().validation_step(batch, batch_idx) - def validation_epoch_end(self, outputs): - self.validation_epoch_end_call_count += 1 - super().validation_epoch_end(outputs) + def on_validation_epoch_end(self): + self.on_validation_epoch_end_call_count += 1 def test_step(self, batch, batch_idx): self.test_step_call_count += 1 @@ -83,9 +81,9 @@ def test_step(self, batch, batch_idx): def _make_fast_dev_run_assertions(trainer, model): # check the call count for train/val/test step/epoch assert model.training_step_call_count == fast_dev_run - assert model.training_epoch_end_call_count == 1 + assert model.on_train_epoch_end_call_count == 1 assert model.validation_step_call_count == 0 if model.validation_step is None else 
fast_dev_run - assert model.validation_epoch_end_call_count == 0 if model.validation_step is None else 1 + assert model.on_validation_epoch_end_call_count == (0 if model.validation_step is None else 1) assert model.test_step_call_count == fast_dev_run # check trainer arguments diff --git a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py index a984c28acc2de..8d1d6e35c1b0a 100644 --- a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py +++ b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py @@ -47,7 +47,6 @@ def training_step(self, *args, **kwargs): match = "`max_epochs` was not set. Setting it to 1000 epochs." model = CustomModel() - model.training_epoch_end = None trainer = Trainer(max_epochs=None, limit_train_batches=1) with pytest.warns(PossibleUserWarning, match=match): trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index 51ce770cbca71..13bbb2243c3ac 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -84,7 +84,6 @@ def test_all_rank_logging_ddp_spawn(tmpdir): """Check that all ranks can be logged from.""" model = TestModel() all_rank_logger = AllRankLogger() - model.training_epoch_end = None trainer = Trainer( strategy="ddp_spawn", accelerator="gpu", diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 612e48e7703d2..d0aadbf756f33 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -56,7 +56,6 @@ def validation_step(self, batch, batch_idx): model = TestModel() model.validation_step_end = None - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -79,7 +78,7 @@ def validation_step(self, batch, batch_idx): def test__validation_step__epoch_end__log(tmpdir): - """Tests that validation_epoch_end can log.""" + """Tests that on_validation_epoch_end can log.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): out = super().training_step(batch, batch_idx) @@ -94,7 +93,7 @@ def validation_step(self, batch, batch_idx): self.log("d", out["x"], on_step=True, on_epoch=True) return out - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self.log("g", torch.tensor(2, device=self.device), on_epoch=True) model = TestModel() @@ -124,7 +123,7 @@ def validation_epoch_end(self, outputs): @pytest.mark.parametrize(["batches", "log_interval", "max_epochs"], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_logging(tmpdir, batches, log_interval, max_epochs): class TestModel(BoringModel): - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self.log("c", torch.tensor(2), on_epoch=True, prog_bar=True, logger=True) self.log("d/e/f", 2) @@ -188,10 +187,8 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss, on_epoch=True, on_step=True, prog_bar=True) return {"x": loss} - def validation_epoch_end(self, outputs) -> None: - for passed_in, manually_tracked in zip(outputs, self.val_losses): - assert passed_in["x"] == manually_tracked - self.manual_epoch_end_mean = torch.stack([x["x"] for x in outputs]).mean() + def on_validation_epoch_end(self) -> None: - self.manual_epoch_end_mean = torch.stack(self.val_losses).mean() model = TestModel() trainer = Trainer( @@
-217,10 +214,10 @@ def validation_epoch_end(self, outputs) -> None: @pytest.mark.parametrize(["batches", "log_interval", "max_epochs"], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): - """Tests that test_epoch_end can be used to log, and we return them in the results.""" + """Tests that on_test_epoch_end can be used to log, and the logged values are returned in the results.""" class TestModel(BoringModel): - def test_epoch_end(self, outputs): + def on_test_epoch_end(self): self.log("c", torch.tensor(2)) self.log("d/e/f", 2) @@ -255,7 +252,6 @@ def test_dataloader(self): return super().test_dataloader() model = TestModel() - model.test_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -332,7 +328,6 @@ def validation_step(self, batch, batch_idx): self.log("val_loss", loss) model = TestModel() - model.validation_epoch_end = None cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, @@ -459,7 +454,6 @@ def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] model = TestModel() - model.test_epoch_end = None cb = TestCallback() trainer = Trainer( default_root_dir=tmpdir, limit_test_batches=2, num_sanity_val_steps=0, max_epochs=2, callbacks=[cb] ) @@ -533,7 +527,6 @@ def test_step(self, batch, batch_idx): return {"y": loss} model = ExtendedModel() - model.validation_epoch_end = None # Initialize a trainer trainer = Trainer( @@ -596,6 +589,8 @@ def get_metrics_at_idx(idx): @pytest.mark.parametrize("val_check_interval", [0.5, 1.0]) def test_multiple_dataloaders_reset(val_check_interval, tmpdir): class TestModel(BoringModel): + val_outputs = [[], []] + def training_step(self, batch, batch_idx): out = super().training_step(batch, batch_idx) value = 1 + batch_idx @@ -604,7 +599,7 @@ def training_step(self, batch, batch_idx): self.log("batch_idx", value, on_step=True, on_epoch=True, prog_bar=True) return out - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): metrics = self.trainer.progress_bar_metrics v = 15 if self.current_epoch == 0 else 150 assert metrics["batch_idx_epoch"] == (v / 5.0) @@ -613,10 +608,13 @@ def validation_step(self, batch, batch_idx, dataloader_idx): value = (1 + batch_idx) * (1 + dataloader_idx) if self.current_epoch != 0: value *= 10 + self.val_outputs[dataloader_idx].append(value) self.log("val_loss", value, on_step=False, on_epoch=True, prog_bar=True, logger=True) - return value - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): + outputs = self.val_outputs + self.val_outputs = [[], []] + if self.current_epoch == 0: assert sum(outputs[0]) / 5 == 3 assert sum(outputs[1]) / 5 == 6 @@ -658,6 +656,8 @@ def val_dataloader(self): ) def test_metrics_and_outputs_device(tmpdir, accelerator): class TestModel(BoringModel): + outputs = [] + def on_before_backward(self, loss: Tensor) -> None: # the loss should be on the correct device before backward assert loss.device.type == accelerator @@ -667,13 +667,14 @@ def validation_step(self, *args): y = x * 2 assert x.requires_grad is True assert y.grad_fn is None # disabled by validation self.log("foo", y) + self.outputs.append(y) return y - def validation_epoch_end(self, outputs): - # the step outputs were not moved - assert all(o.device == self.device for o in outputs) + def on_validation_epoch_end(self): + # the step outputs were not moved after returning them + assert all(o.device == self.device for o in self.outputs) # and the logged metrics aren't assert
self.trainer.callback_metrics["foo"].device.type == accelerator @@ -706,7 +706,6 @@ def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)] model = CustomBoringModel() - model.test_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) results = trainer.test(model) @@ -728,12 +727,16 @@ def test_dataloader(self): @mock.patch("lightning.pytorch.loggers.TensorBoardLogger.log_metrics") def test_logging_multi_dataloader_on_epoch_end(mock_log_metrics, tmpdir): class CustomBoringModel(BoringModel): + outputs = [[], []] + def test_step(self, batch, batch_idx, dataloader_idx): - self.log("foo", dataloader_idx + 1) - return dataloader_idx + 1 + value = dataloader_idx + 1 + self.log("foo", value) + self.outputs[dataloader_idx].append(value) + return value - def test_epoch_end(self, outputs) -> None: - self.log("foobar", sum(sum(o) for o in outputs)) + def on_test_epoch_end(self): + self.log("foobar", sum(sum(o) for o in self.outputs)) def test_dataloader(self): return [super().test_dataloader(), super().test_dataloader()] @@ -742,7 +745,7 @@ def test_dataloader(self): trainer = Trainer(default_root_dir=tmpdir, limit_test_batches=1, logger=TensorBoardLogger(tmpdir)) results = trainer.test(model) - # what's logged in `test_epoch_end` gets included in the results of each dataloader + # what's logged in `on_test_epoch_end` gets included in the results of each dataloader assert results == [{"foo/dataloader_idx_0": 1, "foobar": 3}, {"foo/dataloader_idx_1": 2, "foobar": 3}] cb_metrics = set(trainer.callback_metrics) assert cb_metrics == {"foo/dataloader_idx_0", "foo/dataloader_idx_1", "foobar"} @@ -960,9 +963,6 @@ def val_dataloader(self): def test_dataloader(self): return [super().test_dataloader()] * num_dataloaders - validation_epoch_end = None - test_epoch_end = None - limit_batches = 4 max_epochs = 3 trainer = Trainer( diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index 8bcb9ee982033..73caa92ae7f54 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -272,12 +272,11 @@ def training_step(self, *args, **kwargs): self.log("train_loss_epoch", result["loss"], on_step=False, on_epoch=True) return result - def training_step_end(self, training_step_outputs): # required for dp - loss = training_step_outputs["loss"].mean() + def training_step_end(self, training_step_output): # required for dp + loss = training_step_output["loss"].mean() return loss - def training_epoch_end(self, outputs): - assert all(out["loss"].device == root_device for out in outputs) + def on_train_epoch_end(self): assert self.trainer.callback_metrics["train_loss_epoch"].device == root_device def validation_step(self, *args, **kwargs): @@ -285,8 +284,7 @@ def validation_step(self, *args, **kwargs): self.log("val_loss_epoch", val_loss, on_step=False, on_epoch=True) return val_loss - def validation_epoch_end(self, outputs): - assert all(loss.device == root_device for loss in outputs) + def on_validation_epoch_end(self): assert self.trainer.callback_metrics["val_loss_epoch"].device == root_device def test_step(self, *args, **kwargs): @@ -294,8 +292,7 @@ def test_step(self, *args, **kwargs): self.log("test_loss_epoch", test_loss, on_step=False, on_epoch=True) return test_loss - def test_epoch_end(self, outputs): - assert all(loss.device == root_device for loss in outputs) + def 
on_test_epoch_end(self): assert self.trainer.callback_metrics["test_loss_epoch"].device == root_device def train_dataloader(self): @@ -321,37 +318,6 @@ def test_dataloader(self): trainer.test(model) -def test_can_return_tensor_with_more_than_one_element(tmpdir): - """Ensure {validation,test}_step return values are not included as callback metrics. - - #6623 - """ - - class TestModel(BoringModel): - def validation_step(self, batch, *args, **kwargs): - return {"val": torch.tensor([0, 1])} - - def validation_epoch_end(self, outputs): - # ensure validation step returns still appear here - assert len(outputs) == 2 - assert all(list(d) == ["val"] for d in outputs) # check keys - assert all(torch.equal(d["val"], torch.tensor([0, 1])) for d in outputs) # check values - - def test_step(self, batch, *args, **kwargs): - return {"test": torch.tensor([0, 1])} - - def test_epoch_end(self, outputs): - assert len(outputs) == 2 - assert all(list(d) == ["test"] for d in outputs) # check keys - assert all(torch.equal(d["test"], torch.tensor([0, 1])) for d in outputs) # check values - - model = TestModel() - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, enable_progress_bar=False) - trainer.fit(model) - trainer.validate(model) - trainer.test(model) - - @pytest.mark.parametrize("add_dataloader_idx", [False, True]) def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx): """test that auto_add_dataloader_idx argument works.""" @@ -372,7 +338,6 @@ def validation_step(self, *args, **kwargs): return output model = TestModel() - model.validation_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2) trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py index a3c297ed0f538..32114f8ba54bf 100644 --- a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py @@ -70,7 +70,6 @@ def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs): "on_train_start", "on_train_epoch_start", "on_train_epoch_end", - "training_epoch_end", ] all_logging_hooks = all_logging_hooks - set(hooks) _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs) @@ -85,7 +84,6 @@ def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs): "on_validation_batch_end", "validation_step", "validation_step_end", - "validation_epoch_end", ] all_logging_hooks = all_logging_hooks - set(hooks) _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs) @@ -100,7 +98,6 @@ def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs): "on_test_batch_end", "test_step", "test_step_end", - "test_epoch_end", ] all_logging_hooks = all_logging_hooks - set(hooks) _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs) diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py index 6362aa54057bf..019d609a8df3c 100644 --- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py @@ -103,8 +103,6 @@ def training_step(self, batch, batch_idx): def test__training_step__epoch_end__log(tmpdir): - """Tests that training_epoch_end can log.""" - class TestModel(BoringModel): def training_step(self, batch, batch_idx): out = 
super().training_step(batch, batch_idx) @@ -113,9 +111,9 @@ def training_step(self, batch, batch_idx): self.log_dict({"a1": loss, "a2": loss}) return out - def training_epoch_end(self, outputs): - self.log("b1", outputs[0]["loss"]) - self.log("b", outputs[0]["loss"], on_epoch=True, prog_bar=True, logger=True) + def on_train_epoch_end(self): + self.log("b1", torch.tensor(1.0)) + self.log("b", torch.tensor(2.0), on_epoch=True, prog_bar=True, logger=True) model = TestModel() model.val_dataloader = None @@ -144,7 +142,7 @@ def training_epoch_end(self, outputs): @pytest.mark.parametrize(["batches", "log_interval", "max_epochs"], [(1, 1, 1), (64, 32, 2)]) def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval, max_epochs): - """Tests that training_step_end and training_epoch_end can log.""" + """Tests that training_step_end and on_train_epoch_end can log.""" class TestModel(BoringModel): def training_step(self, batch): @@ -156,8 +154,8 @@ def training_step_end(self, out): self.log("b", out, on_step=True, on_epoch=True, prog_bar=True, logger=True) return out - def training_epoch_end(self, outputs): - self.log("c", outputs[0]["loss"], on_epoch=True, prog_bar=True, logger=True) + def on_train_epoch_end(self): + self.log("c", 1, on_epoch=True, prog_bar=True, logger=True) self.log("d/e/f", 2) model = TestModel() @@ -722,9 +720,13 @@ def training_step(self, batch, batch_idx): def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir): class TestModel(BoringModel): def on_train_epoch_end(self): + self.log("on_train_epoch_end", 3.0, reduce_fx="mean") + assert self.trainer._results["on_train_epoch_end.on_train_epoch_end"].value == 3.0 assert all(v == 3 for v in self.trainer.callback_metrics.values()) def on_validation_epoch_end(self): + self.log("on_validation_epoch_end", 3.0, reduce_fx="mean") + assert self.trainer._results["on_validation_epoch_end.on_validation_epoch_end"].value == 3.0 assert all(v == 3 for v in self.trainer.callback_metrics.values()) def on_train_batch_start(self, batch, batch_idx): @@ -739,16 +741,9 @@ def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.log("on_validation_batch_end", 1.0, reduce_fx="sum") - def training_epoch_end(self, *_) -> None: - self.log("training_epoch_end", 3.0, reduce_fx="mean") - assert self.trainer._results["training_epoch_end.training_epoch_end"].value == 3.0 - - def validation_epoch_end(self, *_) -> None: - self.log("validation_epoch_end", 3.0, reduce_fx="mean") - assert self.trainer._results["validation_epoch_end.validation_epoch_end"].value == 3.0 - model = TestModel() trainer = Trainer( + default_root_dir=tmpdir, enable_progress_bar=False, limit_train_batches=3, limit_val_batches=3, diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py index 0e98a962ede55..4c09ec3e1a710 100644 --- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py +++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py @@ -74,25 +74,13 @@ def configure_optimizers(self): @pytest.mark.parametrize( "kwargs", [{}, pytest.param({"accelerator": "gpu", "devices": 1, "precision": 16}, marks=RunIf(min_cuda_gpus=1))] ) -def test_multiple_optimizers_manual_no_return(tmpdir, kwargs): - class TestModel(ManualOptModel): - def training_step(self, batch, batch_idx): - # avoid returning a value - super().training_step(batch, batch_idx) - - 
def training_epoch_end(self, outputs): - # outputs is empty as training_step does not return - # and it is not automatic optimization - assert not outputs - - model = TestModel() - model.val_dataloader = None - +def test_multiple_optimizers_manual_call_counts(tmpdir, kwargs): + model = ManualOptModel() limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=limit_train_batches, - limit_val_batches=2, + limit_val_batches=0, max_epochs=1, log_every_n_steps=1, enable_model_summary=False, @@ -109,58 +97,25 @@ def training_epoch_end(self, outputs): with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 + assert trainer.global_step == limit_train_batches * 2 if kwargs.get("precision") == 16: scaler_step_patch.stop() assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches -def test_multiple_optimizers_manual_return(tmpdir): - class TestModel(ManualOptModel): - def training_step(self, batch, batch_idx): - super().training_step(batch, batch_idx) - return {"something": "else"} - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert outputs == [{"something": "else"}, {"something": "else"}] - - model = TestModel() - model.val_dataloader = None - - limit_train_batches = 2 - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - limit_val_batches=2, - max_epochs=1, - log_every_n_steps=1, - enable_model_summary=False, - ) - - with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: - trainer.fit(model) - assert bwd_mock.call_count == limit_train_batches * 3 - assert trainer.global_step == limit_train_batches * 2 - - def test_multiple_optimizers_manual_log(tmpdir): class TestModel(ManualOptModel): def training_step(self, batch, batch_idx): loss_2 = super().training_step(batch, batch_idx) self.log("a", loss_2, on_epoch=True) - def training_epoch_end(self, outputs) -> None: - assert not outputs - model = TestModel() - model.val_dataloader = None - limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=limit_train_batches, - limit_val_batches=2, + limit_val_batches=0, max_epochs=1, log_every_n_steps=1, enable_model_summary=False, @@ -262,7 +217,6 @@ def test_manual_optimization_and_return_tensor(tmpdir): model = ManualOptimizationExtendedModel() model.training_step_end = None - model.training_epoch_end = None trainer = Trainer( max_epochs=1, @@ -348,7 +302,6 @@ def on_train_epoch_end(self, *_, **__): model = ExtendedModel() model.training_step_end = None - model.training_epoch_end = None trainer = Trainer( max_epochs=1, @@ -396,10 +349,6 @@ def training_step(self, batch, batch_idx): return {"loss1": loss_1.detach(), "loss2": loss_2.detach()} - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale def on_after_backward(self) -> None: # check grads are scaled @@ -498,7 +447,6 @@ def optimizer_closure(): assert not torch.equal(weight_before, weight_after) model = TestModel() - model.training_epoch_end = None limit_train_batches = 2 trainer = Trainer( @@ -540,8 +488,6 @@ def optimizer_closure(): assert not torch.equal(weight_before, weight_after) model = TestModel() - model.training_epoch_end = None - limit_train_batches = 4 trainer = 
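The assertions deleted here all restate one point: with manual optimization, `training_step` does not have to return anything, so after this PR there is no `outputs` list left to check. A minimal sketch of keeping your own record instead, assuming a single optimizer; the layer size and metric name are illustrative, not from this PR:

import torch
import lightning.pytorch as pl


class ManualCachingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)
        self.losses = []

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        # nothing is returned, so the loop accumulates nothing; cache it manually
        self.losses.append(loss.detach())

    def on_train_epoch_end(self):
        # aggregate here, where `training_epoch_end(outputs)` used to run
        self.log("train_loss_epoch", torch.stack(self.losses).mean())
        self.losses.clear()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)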
Trainer( default_root_dir=tmpdir, @@ -615,7 +561,6 @@ def configure_optimizers(self): model = TestModel() model.val_dataloader = None - model.training_epoch_end = None limit_train_batches = 8 trainer = Trainer( @@ -727,8 +672,6 @@ def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationD model = model_cls() model_copy = deepcopy(model) model.val_dataloader = None - model.training_epoch_end = None - limit_train_batches = 8 trainer = Trainer( default_root_dir=tmpdir, @@ -847,7 +790,6 @@ def configure_optimizers(self): return [optimizer_1, optimizer_2], [self.scheduler_1, self.scheduler_2] model = TestModel() - model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, limit_test_batches=1 @@ -864,14 +806,9 @@ def __init__(self, scheduler_as_dict): self.scheduler_as_dict = scheduler_as_dict self.automatic_optimization = False - def training_step(self, batch, batch_idx): - return {"train_loss": torch.tensor([0.0])} - - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): scheduler = self.lr_schedulers() - - loss = torch.stack([x["train_loss"] for x in outputs]).mean() - scheduler.step(loss) + scheduler.step(torch.tensor(0.0)) def configure_optimizers(self): optimizer = torch.optim.SGD(self.parameters(), lr=0.1) @@ -905,7 +842,6 @@ def test_lr_scheduler_step_not_called(tmpdir): """Test `lr_scheduler.step()` is not called in manual optimization.""" model = ManualOptimBoringModel() model.training_step_end = None - model.training_epoch_end = None trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, fast_dev_run=2) @@ -958,7 +894,6 @@ def configure_optimizers(self): return optimizer, optimizer_2 model = TestModel() - model.training_epoch_end = None model.val_dataloader = None trainer = Trainer( diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py index 229430d0168c3..c69f416546019 100644 --- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py @@ -65,11 +65,6 @@ def training_step(self, batch, batch_idx): opt_b.step() opt_b.zero_grad() - def training_epoch_end(self, outputs) -> None: - # outputs is empty as training_step does not return - # and it is not automatic optimization - assert len(outputs) == 0 - model = TestModel() model.val_dataloader = None diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py index a7a25358334fd..7d887380fa5d9 100644 --- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py +++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py @@ -77,7 +77,6 @@ def configure_optimizers(self): model = Model() model.automatic_optimization = False - model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) trainer.fit(model) @@ -540,7 +539,6 @@ def on_save_checkpoint(self, checkpoint): self.on_save_checkpoint_called = True model = Model() - model.training_epoch_end = None trainer.fit(model) assert model.on_save_checkpoint_called @@ -577,7 +575,6 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler1, lr_scheduler2] model = CustomBoringModel() - model.training_epoch_end = None max_epochs = 3 limit_train_batches = 2 trainer = Trainer( diff --git 
a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py index 5f568ab74e553..d25e4f20ee14c 100644 --- a/tests/tests_pytorch/trainer/test_config_validator.py +++ b/tests/tests_pytorch/trainer/test_config_validator.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import Mock + import pytest import torch @@ -18,6 +20,10 @@ from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.trainer.configuration_validator import ( + __verify_eval_loop_configuration, + __verify_train_val_loop_configuration, +) from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.conftest import mock_cuda_count @@ -158,3 +164,28 @@ def custom_method(self, batch, *_, **__): with pytest.raises(MisconfigurationException, match=match_pattern): trainer.fit(model) + + +def test_legacy_epoch_end_hooks(): + class TrainingEpochEndModel(BoringModel): + def training_epoch_end(self, outputs): + pass + + class ValidationEpochEndModel(BoringModel): + def validation_epoch_end(self, outputs): + pass + + trainer = Mock() + with pytest.raises(NotImplementedError, match="training_epoch_end` has been removed in v2.0"): + __verify_train_val_loop_configuration(trainer, TrainingEpochEndModel()) + with pytest.raises(NotImplementedError, match="validation_epoch_end` has been removed in v2.0"): + __verify_train_val_loop_configuration(trainer, ValidationEpochEndModel()) + + class TestEpochEndModel(BoringModel): + def test_epoch_end(self, outputs): + pass + + with pytest.raises(NotImplementedError, match="validation_epoch_end` has been removed in v2.0"): + __verify_eval_loop_configuration(ValidationEpochEndModel(), "val") + with pytest.raises(NotImplementedError, match="test_epoch_end` has been removed in v2.0"): + __verify_eval_loop_configuration(TestEpochEndModel(), "test") diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 6e6463c1499cb..5af5ad752ee9a 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -47,9 +47,6 @@ def val_dataloader(self): def validation_step(self, batch, batch_idx, dataloader_idx): return super().validation_step(batch, batch_idx) - def validation_epoch_end(self, *args, **kwargs): - pass - class MultiTestDataLoaderBoringModel(BoringModel): def test_dataloader(self): @@ -58,9 +55,6 @@ def test_dataloader(self): def test_step(self, batch, batch_idx, dataloader_idx): return super().test_step(batch, batch_idx) - def test_epoch_end(self, *args, **kwargs): - pass - class MultiEvalDataLoaderModel(MultiValDataLoaderBoringModel, MultiTestDataLoaderBoringModel): pass @@ -75,10 +69,8 @@ def test_fit_train_loader_only(tmpdir): model.test_dataloader = None model.validation_step = None - model.validation_epoch_end = None model.test_step = None - model.test_epoch_end = None trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, train_dataloaders=train_dataloader) @@ -94,7 +86,6 @@ def test_fit_val_loader_only(tmpdir): model.test_dataloader = None model.test_step = None - model.test_epoch_end = None trainer = Trainer(fast_dev_run=True, 
default_root_dir=tmpdir) trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) @@ -208,7 +199,7 @@ def training_step(self, batch, batch_idx): self.log("loss", self.global_step) return super().training_step(batch, batch_idx) - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self.log("val_log", self.current_epoch) @@ -657,7 +648,7 @@ def __init__(self): def training_step(self, batch, batch_idx): self.batches_seen.append(batch) - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): world_size = 2 num_samples = NumpyRandomDataset.size all_batches = torch.cat(self.batches_seen) @@ -1050,9 +1041,8 @@ def val_dataloader(self): val_reload_epochs.append(self.current_epoch) return super().val_dataloader() - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): val_check_epochs.append(self.current_epoch) - return super().validation_epoch_end(outputs) model = TestModel() @@ -1274,17 +1264,6 @@ def predict(self, batch, batch_idx, dataloader_idx): self.assert_dataloader_idx_hook(dataloader_idx) return super().predict(batch, batch_idx, dataloader_idx) - def assert_epoch_end_outputs(self, outputs, mode): - assert len(outputs) == 2 - assert all(f"{mode}_loss_0" in x for x in outputs[0]) - assert all(f"{mode}_loss_1" in x for x in outputs[1]) - - def validation_epoch_end(self, outputs): - self.assert_epoch_end_outputs(outputs, mode="val") - - def test_epoch_end(self, outputs): - self.assert_epoch_end_outputs(outputs, mode="test") - def train_dataloader(self): return {"a": DataLoader(RandomDataset(32, 64)), "b": DataLoader(RandomDataset(32, 64))} diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 7af11527f2a46..38546e40a73e8 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -808,7 +808,6 @@ def predict_step(self, batch, *_): return self(batch) model = TestModel() - model.test_epoch_end = None trainer = Trainer( max_epochs=2, limit_val_batches=1, @@ -879,7 +878,6 @@ def predict_step(self, batch, *_): return self(batch) model = TestModel() - model.test_epoch_end = None trainer = Trainer( max_epochs=2, limit_val_batches=1, @@ -930,16 +928,11 @@ def test_disabled_training(tmpdir): class CurrentModel(BoringModel): training_step_invoked = False - training_epoch_end_invoked = False def training_step(self, *args, **kwargs): self.training_step_invoked = True return super().training_step(*args, **kwargs) - def training_epoch_end(self, *args, **kwargs): - self.training_epoch_end_invoked = True - return super().training_epoch_end(*args, **kwargs) - model = CurrentModel() trainer_options = dict( @@ -965,7 +958,6 @@ def training_epoch_end(self, *args, **kwargs): assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.current_epoch == 0 assert not model.training_step_invoked, "`training_step` should not run when `limit_train_batches=0`" - assert not model.training_epoch_end_invoked, "`training_epoch_end` should not run when `limit_train_batches=0`" # check that limit_train_batches has no influence when fast_dev_run is turned on model = CurrentModel() @@ -983,7 +975,6 @@ def training_epoch_end(self, *args, **kwargs): assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.current_epoch == 1 assert model.training_step_invoked, "did not run `training_step` with `fast_dev_run=True`" - assert model.training_epoch_end_invoked, "did not run 
`training_epoch_end` with `fast_dev_run=True`" def test_disabled_validation(tmpdir): @@ -992,16 +983,11 @@ def test_disabled_validation(tmpdir): class CurrentModel(BoringModel): validation_step_invoked = False - validation_epoch_end_invoked = False def validation_step(self, *args, **kwargs): self.validation_step_invoked = True return super().validation_step(*args, **kwargs) - def validation_epoch_end(self, *args, **kwargs): - self.validation_epoch_end_invoked = True - return super().validation_epoch_end(*args, **kwargs) - model = CurrentModel() trainer_options = dict( @@ -1020,7 +1006,6 @@ def validation_epoch_end(self, *args, **kwargs): assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.current_epoch == 2 assert not model.validation_step_invoked, "`validation_step` should not run when `limit_val_batches=0`" - assert not model.validation_epoch_end_invoked, "`validation_epoch_end` should not run when `limit_val_batches=0`" # check that limit_val_batches has no influence when fast_dev_run is turned on model = CurrentModel() @@ -1031,7 +1016,6 @@ def validation_epoch_end(self, *args, **kwargs): assert trainer.state.finished, f"Training failed with {trainer.state}" assert trainer.current_epoch == 1 assert model.validation_step_invoked, "did not run `validation_step` with `fast_dev_run=True`" - assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`" @pytest.mark.parametrize("track_grad_norm", [0, torch.tensor(1), "nan"]) @@ -1166,7 +1150,6 @@ def val_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] model = CustomModel() - model.validation_epoch_end = None num_sanity_val_steps = 4 trainer = Trainer( @@ -1182,7 +1165,6 @@ def val_dataloader(self): return [DataLoader(RandomDataset(32, 64), batch_size=8), DataLoader(RandomDataset(32, 64))] model = CustomModelMixedVal() - model.validation_epoch_end = None with patch.object( trainer.fit_loop.epoch_loop.val_loop.epoch_loop, @@ -1208,7 +1190,6 @@ def val_dataloader(self): return [DataLoader(RandomDataset(32, 64)), DataLoader(RandomDataset(32, 64))] model = CustomModel() - model.validation_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=-1, limit_val_batches=limit_val_batches, max_steps=1 ) @@ -1725,9 +1706,6 @@ def validation_step(self, batch, batch_idx): loss = self.step(batch) self.log("x", loss) - def validation_epoch_end(self, outputs) -> None: - pass - @RunIf(skip_windows=True) def test_fit_test_synchronization(tmpdir): diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index 6326bf00d0eb6..2e12ea172949f 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -459,7 +459,6 @@ def train_dataloader(self): model = CustomBatchSizeModel(batch_size=16) model.validation_step = None - model.training_epoch_end = None scale_batch_size_kwargs = {"max_trials": 10, "steps_per_trial": 1, "init_val": 500, "mode": scale_method} trainer = Trainer(default_root_dir=tmpdir, max_epochs=2) diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py index 2ccbdfe38a224..80f765290a99f 100644 --- a/tests/tests_pytorch/utilities/test_all_gather_grad.py +++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py @@ -53,20 +53,19 @@ def test_all_gather_ddp_spawn(): def test_all_gather_collection(tmpdir): class 
TestModel(BoringModel): - training_epoch_end_called = False + on_train_epoch_end_called = False - def training_epoch_end(self, outputs) -> None: - losses = torch.stack([x["loss"] for x in outputs]) + def on_train_epoch_end(self): + losses = torch.rand(2, 2).t() gathered_loss = self.all_gather( { - "losses_tensor_int": torch.rand(2, 2).int().t(), - "losses_tensor_float": torch.rand(2, 2).t(), + "losses_tensor_int": losses.int(), + "losses_tensor_float": losses, + "losses_tensor_list": [losses, losses], "losses_np_ndarray": np.array([1, 2, 3]), "losses_bool": [True, False], "losses_float": [0.0, 1.0, 2.0], "losses_int": [0, 1, 2], - "losses": losses, - "losses_list": [losses, losses], } ) assert gathered_loss["losses_tensor_int"][0].dtype == torch.int32 @@ -76,9 +75,16 @@ def training_epoch_end(self, outputs) -> None: assert gathered_loss["losses_bool"][0].dtype == torch.uint8 assert gathered_loss["losses_float"][0].dtype == torch.float assert gathered_loss["losses_int"][0].dtype == torch.int - assert gathered_loss["losses_list"][0].numel() == 2 * len(losses) - assert gathered_loss["losses"].numel() == 2 * len(losses) - self.training_epoch_end_called = True + + losses_numel = losses.numel() + assert gathered_loss["losses_tensor_int"].numel() == 2 * losses_numel + assert gathered_loss["losses_tensor_float"].numel() == 2 * losses_numel + assert torch.stack(gathered_loss["losses_tensor_list"]).shape == (2, 2, 2, 2) + assert gathered_loss["losses_np_ndarray"].numel() == 2 * 3 + assert torch.stack(gathered_loss["losses_bool"]).shape == (2, 2) + assert torch.stack(gathered_loss["losses_float"]).shape == (3, 2) + assert torch.stack(gathered_loss["losses_int"]).shape == (3, 2) + self.on_train_epoch_end_called = True model = TestModel() trainer = Trainer( @@ -96,7 +102,7 @@ def training_epoch_end(self, outputs) -> None: enable_checkpointing=False, ) trainer.fit(model) - assert model.training_epoch_end_called + assert model.on_train_epoch_end_called @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index 9a880dc9da995..d6a4c96e5f530 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -55,11 +55,11 @@ def validation_step(self, batch, batch_idx): self._signal() return super().validation_step(batch, batch_idx) - def training_epoch_end(self, outputs) -> None: + def on_train_epoch_end(self): if not self.failure_on_step and self.failure_on_training: self._signal() - def validation_epoch_end(self, outputs) -> None: + def on_validation_epoch_end(self): if not self.failure_on_step and not self.failure_on_training: self._signal() @@ -127,7 +127,7 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on if val_check_interval == 1.0: status = "_FitLoop:on_advance_end" else: - # `training_epoch_end` happens after `validation_epoch_end` since Lightning v1.4 + # `on_train_epoch_end` happens after `on_validation_epoch_end` since Lightning v1.4 status = "_FitLoop:on_advance_end" if failure_on_training else "_TrainingEpochLoop:on_advance_end" _fit_model(tmpdir, True, val_check_interval, failure_on_step, failure_on_training, on_last_batch, status=status) diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index 55d26c5c88ff5..3d3aaee498b06 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ 
b/tests/tests_pytorch/utilities/test_fetching.py @@ -195,10 +195,10 @@ def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: torch.cuda._sleep(self.CYCLES_PER_MS * 50) return batch - def training_step_end(self, training_step_outputs): + def training_step_end(self, training_step_output): # emulate heavy routine torch.cuda._sleep(self.CYCLES_PER_MS * 50) - return training_step_outputs + return training_step_output def configure_optimizers(self): return torch.optim.SGD(self.parameters(), lr=0.1) @@ -244,7 +244,7 @@ def training_step(self, dataloader_iter, batch_idx): loss.backward() opt.step() - def training_epoch_end(self, *_): + def on_train_epoch_end(self): assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33 assert self.trainer.fit_loop._data_fetcher.fetched == 64 assert self.count == 64 @@ -456,8 +456,6 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): def val_dataloader(self): return [super().val_dataloader(), super().val_dataloader()] - validation_epoch_end = None - model = MyModel() fast_dev_run = 2 trainer = Trainer(