
Add predict_epoch_end hook. #9380

Closed
rohitgr7 opened this issue Sep 8, 2021 · 8 comments · Fixed by #16520
Assignees
Labels
design Includes a design discussion feature Is an improvement or enhancement help wanted Open to be worked on

Comments

@rohitgr7
Contributor

rohitgr7 commented Sep 8, 2021

🚀 Feature

Motivation

Motivation: #9379
Also, I remember it's a TODO somewhere.

Pitch

The hook would be similar to {val/test}_epoch_end, but it would return the outputs.
Also, should we update the signature of on_predict_epoch_end so that it no longer accepts the outputs? These on_* hooks don't return anything, so even if someone modifies the predictions there, it has no effect on the original predictions.
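To make the proposed semantics concrete, here is a minimal standalone sketch (toy classes, not the actual Lightning API): the loop accumulates every `predict_step` output and, if `predict_epoch_end` is overridden, its return value replaces the raw outputs as the result of predict.

```python
# Hypothetical sketch of the proposed `predict_epoch_end` hook.
# `ToyModule` and `run_predict` stand in for LightningModule and
# Trainer.predict; they are illustrations, not Lightning code.

class ToyModule:
    def predict_step(self, batch, batch_idx):
        # a trivial "model": double each input
        return [x * 2 for x in batch]

    def predict_epoch_end(self, outputs):
        # post-process the accumulated per-batch outputs, e.g. flatten them
        return [y for batch_out in outputs for y in batch_out]

def run_predict(module, dataloader):
    outputs = [module.predict_step(batch, i) for i, batch in enumerate(dataloader)]
    # if the hook is overridden, its return value replaces the raw outputs
    if hasattr(module, "predict_epoch_end"):
        return module.predict_epoch_end(outputs)
    return outputs

preds = run_predict(ToyModule(), [[1, 2], [3, 4]])
print(preds)  # [2, 4, 6, 8]
```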

Alternatives

Can't think of any.

Additional context


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.

@rohitgr7 rohitgr7 added feature Is an improvement or enhancement help wanted Open to be worked on design Includes a design discussion labels Sep 8, 2021
@rohitgr7 rohitgr7 self-assigned this Sep 8, 2021
@ananthsub
Contributor

ananthsub commented Sep 8, 2021

There are a number of issues with prediction right now that are, at a minimum, blocking FB's usage of Trainer.predict:

  1. Inconsistent API around outputs, as mentioned here: Inconsistent API for on_predict_epoch_end #8479
  2. Predictions are by default stored & returned: https://github.com/PyTorchLightning/pytorch-lightning/blob/a079d7fccc0a9be25b40296f2a348c4b4f40c8cf/pytorch_lightning/trainer/trainer.py#L793-L794 . API wise this is inconsistent with validate and test. More critically for us, this risks OOMs for large-scale prediction unless users are careful to disable this flag.
  3. The Trainer is currently inconsistent around checks for batch samplers. This first checks if the dataloader has a batch sampler before applying the wrapper for prediction. https://github.com/PyTorchLightning/pytorch-lightning/blob/8407238d66df14c4476f880a6e6260b4bfa83b40/pytorch_lightning/trainer/data_loading.py#L161-L171

But this unconditionally accesses the attribute: https://github.com/PyTorchLightning/pytorch-lightning/blob/8407238d66df14c4476f880a6e6260b4bfa83b40/pytorch_lightning/loops/epoch/prediction_epoch_loop.py#L163-L164

A quick fix could be to check whether the dataloader has a batch sampler here: https://github.com/PyTorchLightning/pytorch-lightning/blob/8407238d66df14c4476f880a6e6260b4bfa83b40/pytorch_lightning/loops/epoch/prediction_epoch_loop.py#L163-L164

  4. RE: epoch end hooks, [RFC] Deprecate the _epoch_end hooks #8731 has more discussion on this. I personally think we should not add these hooks and should instead ask users to either store what is currently returned from predict_step inside the LightningModule, or have callbacks do the post-processing in on_predict_batch_end. Storing data the trainer doesn't directly use can quickly lead to bugs (if we get some post-processing wrong) or performance slowdowns (if we use more memory than we need).
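The callback alternative described above can be sketched without Lightning at all. In this toy version (names mirror the `on_predict_batch_end` hook, but `PredictionCollector` and `predict_loop` are hypothetical stand-ins), the callback accumulates predictions itself, so the trainer never stores them and `return_predictions` can stay disabled:

```python
# Standalone sketch of accumulating predictions in a callback rather
# than in the trainer; not actual Lightning classes.

class PredictionCollector:
    def __init__(self):
        self.predictions = []

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        # keep the outputs ourselves instead of relying on the trainer
        self.predictions.extend(outputs)

def predict_loop(predict_step, dataloader, callbacks):
    for batch_idx, batch in enumerate(dataloader):
        outputs = predict_step(batch, batch_idx)
        for cb in callbacks:
            cb.on_predict_batch_end(outputs, batch, batch_idx)

collector = PredictionCollector()
predict_loop(lambda batch, i: [x + 1 for x in batch], [[0, 1], [2]], [collector])
print(collector.predictions)  # [1, 2, 3]
```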

cc @tchaton

@rohitgr7
Contributor Author

rohitgr7 commented Sep 9, 2021

Accumulating predictions is pretty much just boilerplate code in the usual cases, and if Lightning can provide it on the fly, then I think predict_epoch_end is a useful hook to have. At the very least, there should be some default structure that users can rely on without writing the same duplicate logic, since accumulating predictions across different dataloaders isn't trivial for everyone (at least for beginners). This is of course optional: if users want, they can write their own logic and disable return_predictions.
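The multi-dataloader boilerplate mentioned above looks roughly like the following. This is a standalone toy (no Lightning required); the convention of returning a list of per-dataloader lists, unwrapped when there is only one dataloader, is an assumption for illustration:

```python
# Sketch of accumulating predictions separately per dataloader.

def predict_multi(predict_step, dataloaders):
    all_outputs = []
    for dl_idx, dataloader in enumerate(dataloaders):
        dl_outputs = [predict_step(batch, batch_idx, dl_idx)
                      for batch_idx, batch in enumerate(dataloader)]
        all_outputs.append(dl_outputs)
    # unwrap the single-dataloader case so users get a flat list back
    return all_outputs[0] if len(all_outputs) == 1 else all_outputs

# two dataloaders, batches are lists of numbers, "prediction" is the sum
preds = predict_multi(lambda b, bi, di: sum(b), [[[1, 2], [3]], [[4]]])
print(preds)  # [[3, 3], [4]]
```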

@rohitgr7
Contributor Author

rohitgr7 commented Sep 9, 2021

+1 for deprecating outputs from on_predict_epoch_end for consistency with other hooks if we implement predict_epoch_end.

@tchaton
Contributor

tchaton commented Sep 9, 2021

  1. I believe this is reasonable, as you expect predictions to be returned when performing predict, and there is a simple way to opt out. We could add a warning informing users that it might cause OOM, since predictions are stored, and advise the BasePredictionWriter alternative.

  2. Good catch!

  3. I believe this is a question of accessibility vs. engineering simplicity. I think it is intuitive for the predictions to be saved, but we might want to rethink the API for real-world use cases.

@m13uz

m13uz commented Aug 4, 2022

So what is the status of this feature?

@dagap

dagap commented Oct 18, 2022

+1 on predict_epoch_end which allows one to modify the outputs.

@carmocca carmocca added this to the future milestone Nov 7, 2022
@Borda Borda self-assigned this Nov 7, 2022
@CompRhys
Contributor

I would also benefit from this feature!

@sofroniewn

I would also benefit from this feature, curious if there is any update on plans here. Or I'm curious if there are any alternatives/ best practices that people have adopted that I could learn from. My use case is the same as in #9379.

I could imagine using on_predict_epoch_end to do my post-processing and store the results on my LightningModule, but I don't love that approach: I quite like the syntax of predictions = trainer.predict(lm, dm) and find it a bit strange to store data on the LightningModule itself.

Thanks!!
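For reference, the workaround described in the comment above can be sketched in isolation (toy stand-ins for LightningModule and the predict loop, not real Lightning classes): do the post-processing in `on_predict_epoch_end`, stash the result on the module, and read it back after predicting.

```python
# Standalone sketch of storing post-processed predictions on the module.

class ToyModule:
    def __init__(self):
        self.processed = None

    def predict_step(self, batch, batch_idx):
        return max(batch)

    def on_predict_epoch_end(self, outputs):
        # post-process and store on the module instead of returning
        self.processed = sorted(outputs)

def run_predict(module, dataloader):
    outputs = [module.predict_step(batch, i) for i, batch in enumerate(dataloader)]
    module.on_predict_epoch_end(outputs)
    return outputs

lm = ToyModule()
run_predict(lm, [[3, 1], [0, 5]])
print(lm.processed)  # [3, 5]
```

The drawback noted in the thread applies: the results live on the module rather than being returned, which is less ergonomic than `predictions = trainer.predict(lm, dm)`.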


9 participants