
RFC: Change input format of training_epoch_end hook in case of multiple optimizers or TBPTT #9737

Closed
awaelchli opened this issue Sep 28, 2021 · 0 comments · Fixed by #12182
Labels: deprecation (Includes a deprecation refactor)

awaelchli (Contributor) commented on Sep 28, 2021

Proposed refactoring or deprecation

Change the format of the outputs that the training_epoch_end and training_step_end hooks receive. This is a breaking change and requires a careful deprecation path. It affects users who train with multiple optimizers or truncated backpropagation through time (TBPTT) and who also override the aforementioned hooks.

Motivation

When using multiple optimizers or truncated backprop (or both), the outputs passed to the training_epoch_end hook are the individual training_step outputs arranged in a nested 2D or 3D list of lists. In the general case this multi-dimensional array has the shape (num_optimizers, num_batches, num_tbptt_splits); when num_optimizers=1 or truncated backprop is deactivated, the corresponding dimensions are squeezed away. The problem is that the order of these dimensions does not correspond to the loop structure:

for batch in dataloader:
    for split in batch:          # TBPTT splits of the batch
        for opt in optimizers:   # one training_step per optimizer
            ...

This means the current output format will never generalize to loop customization, since the ordering of the dimensions is arbitrary.
The permutation of dimensions is currently hard-coded and will break for custom loops.
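
For illustration, a minimal sketch of how the current format has to be indexed inside the hook, assuming multiple optimizers and TBPTT are both active (the variable names are illustrative, not part of the Lightning API):

    def training_epoch_end(self, outputs):
        # Current format: outputs[opt_idx][batch_idx][split_idx]
        # i.e. (num_optimizers, num_batches, num_tbptt_splits)
        for opt_idx, optimizer_outputs in enumerate(outputs):
            for batch_idx, batch_outputs in enumerate(optimizer_outputs):
                for split_idx, step_output in enumerate(batch_outputs):
                    ...  # step_output is whatever training_step returned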

Pitch

Deprecate the current format and make it consistent with the loop structure, i.e. adopt the format
(num_batches, num_tbptt_splits, num_optimizers). This corresponds 1:1 to the loop structure. Standardizing this will unblock custom loops with arbitrary nesting and allow output aggregation with less effort.
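
A sketch of the proposed indexing, which mirrors the loop nesting shown above (names are illustrative):

    def training_epoch_end(self, outputs):
        # Proposed format: outputs[batch_idx][split_idx][opt_idx]
        # i.e. (num_batches, num_tbptt_splits, num_optimizers)
        for batch_outputs in outputs:              # one entry per batch
            for split_outputs in batch_outputs:    # one entry per TBPTT split
                for step_output in split_outputs:  # one entry per optimizer
                    ...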

Proposed deprecation plan:

  1. In 1.5, log a message that the format will change in the future (only if multiple optimizers are in use and the hook is overridden).

  2. In 1.5, the user will change their code given our recommendation and will signal this by adding a new argument to the hook:

    def training_step_end(self, outputs, new_format=True):
        ...
    
    def training_epoch_end(self, outputs, new_format=True):
        ...

    This will trigger the loop to call the hook with the new format for outputs instead of the old one.

  3. In 1.7, the new format will be used unconditionally, which is a breaking change for users who have not adapted their code by then. The new_format=True/False argument will then have no effect and can be removed.

Note: The only purpose of the new_format argument is to be inspected by our loop to infer which format the user expects to receive. We will not pass a value for it, so the user must declare it as a keyword argument with a default (as shown above).
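
One way such an inspection could be implemented is sketched below; this is a hypothetical helper, not existing Lightning code:

    import inspect

    def _wants_new_format(hook) -> bool:
        # Check whether the user's overridden hook declares a `new_format` parameter.
        return "new_format" in inspect.signature(hook).parameters

    # Hypothetical use inside the loop:
    # outputs = new_layout if _wants_new_format(model.training_epoch_end) else old_layout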

Alternative deprecation plan:

Instead of a new_format argument in the signature, one could also add a property to the LightningModule:

class MyLightningModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.v1_7_training_epoch_end_format = True
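
With this alternative, the loop would only need to read the attribute off the module, for example (a hypothetical check, not existing Lightning code):

    use_new_format = getattr(module, "v1_7_training_epoch_end_format", False)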

