Online batch generation for unsupervised learning #2453
Comments
Hi! Thanks for your contribution! Great first issue!
Hi @djbyrne This is already possible in your LightningModule.
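For reference, a minimal sketch of that approach (the `OnlineDataset`/`OnlineModel` names and the toy regression target are made up for illustration): the data never comes from a pre-made dataset, it is produced by a callable each time the dataset iterator is advanced.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl


class OnlineDataset(IterableDataset):
    """Yields samples produced on the fly by a callable instead of a pre-made dataset."""

    def __init__(self, generate_fn, steps_per_epoch: int = 1000):
        self.generate_fn = generate_fn
        self.steps_per_epoch = steps_per_epoch

    def __iter__(self):
        for _ in range(self.steps_per_epoch):
            yield self.generate_fn()


class OnlineModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def generate_sample(self):
        # stand-in for an online data source: one (input, target) pair per call
        x = torch.randn(4)
        return x, x.sum(dim=0, keepdim=True)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        # the DataLoader still collates the generated samples into mini-batches
        return DataLoader(OnlineDataset(self.generate_sample), batch_size=32)
```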
Hey @justusschock, you are correct that this is already possible. I have actually been doing something like this for the RL models in the Lightning Bolts repo: https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/rl/common/experience.py Although this works, it seems a bit cumbersome for users to go off and create their own data source that feeds the iterator for the dataloader. I know that this hook would be doing practically the same thing, but I feel it would be cleaner and provide a much better user experience.
I see :) I think you are right. I'm not familiar with RL, so to clarify the interface: do you need the function to produce whole batches or just single samples? For whole batches, we would also have to make sure that batching in the DataLoader (if used) is disabled, or not use the DataLoader at all...
It would be both. The typical flow would be: 1) carry out N steps in the environment using the agent, 2) add the experience to a buffer or simply gather it, 3) sample from the buffer or take the gathered experience as the batch. I think the best way would be to have an IterableDataset that consumes this train_batch() hook. I believe that by doing this, batching would not be used.
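A rough sketch of that flow, assuming the classic Gym API (`step()` returning `(obs, reward, done, info)`) and a random-action placeholder agent; the `ExperienceSourceDataset` and `train_batch` names here are illustrative, not a final interface:

```python
from collections import deque

import gym  # assumes the classic Gym API: step() -> (obs, reward, done, info)
import numpy as np
from torch.utils.data import IterableDataset


class ExperienceSourceDataset(IterableDataset):
    """Exposes a batch-generating callable to the training loop as an iterator."""

    def __init__(self, generate_batch):
        self.generate_batch = generate_batch

    def __iter__(self):
        return self.generate_batch()


def train_batch(env, agent, buffer, steps_per_batch=16, sample_size=32):
    """1) step the environment, 2) store the experience, 3) yield sampled batches."""
    state = env.reset()
    while True:
        for _ in range(steps_per_batch):
            action = agent(state, env)
            next_state, reward, done, _ = env.step(action)
            buffer.append((state, action, reward, done, next_state))
            state = env.reset() if done else next_state
        # sample a training batch from everything gathered so far
        idx = np.random.choice(len(buffer), min(sample_size, len(buffer)), replace=False)
        yield [buffer[i] for i in idx]


if __name__ == "__main__":
    def agent(state, env):  # random-policy placeholder
        return env.action_space.sample()

    env = gym.make("CartPole-v0")
    buffer = deque(maxlen=10_000)
    dataset = ExperienceSourceDataset(lambda: train_batch(env, agent, buffer))
    first_batch = next(iter(dataset))  # a list of (s, a, r, done, s') tuples
```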
If you use an IterableDataset and pass it to the loader, I think batching will still be done. We could however bypass this with a custom ...
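The comment trails off; for reference, two standard PyTorch mechanisms that can bypass the extra batching when the dataset already yields whole batches, sketched with a toy dataset (whether either of these is what was meant above is an assumption):

```python
from torch.utils.data import DataLoader, IterableDataset


class BatchYieldingDataset(IterableDataset):
    """Toy dataset whose iterator already yields whole batches (lists of samples)."""

    def __iter__(self):
        for step in range(3):
            yield [(step, i) for i in range(4)]  # one pre-built batch per iteration


dataset = BatchYieldingDataset()

# Option 1: disable automatic batching entirely (batch_size=None), so each
# yielded batch passes through the DataLoader unchanged.
no_batching = DataLoader(dataset, batch_size=None)

# Option 2: keep automatic batching but undo the extra wrapping with a custom
# collate_fn that unwraps the single-item "batch of batches".
custom_collate = DataLoader(dataset, batch_size=1, collate_fn=lambda items: items[0])

for batch in no_batching:
    print(batch)  # arrives exactly as the dataset produced it
```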
This is something I am looking into as well. @djbyrne The code solves MountainCarContinuous-v0 in 100 epochs... which is quite nice :-) //Christofer
@christofer-f that looks awesome!
@justusschock I have pushed a proof of concept for the train_batch interface here: https://github.com/djbyrne/pytorch-lightning-bolts/blob/enhancement/train_batch_function/pl_bolts/models/rl/vanilla_policy_gradient_model.py The datamodules for this can be found here: https://github.com/djbyrne/pytorch-lightning-bolts/blob/enhancement/train_batch_function/pl_bolts/datamodules/experience_source.py This is still a WIP, but I would like to get some feedback on it so far.
I found this article: It seems really useful... Hopefully, I will put some code in my repo to show my ideas...
@christofer-f yeah, that iterable dataset type is what is currently being used in the PyTorch Lightning Bolts RL module.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🚀 Feature
As an alternative to pulling batches from a pre-made dataset, provide users with a hook that allows them to specify how batches are generated for each training step. This would allow users to provide logic for generating a batch online during training, which would be ideal for Reinforcement Learning models and other unsupervised problems.
Motivation
Currently, Lightning expects all models to use a pre-made dataset from which a DataLoader generates mini-batches. This works for most use cases, but for some, such as Reinforcement Learning models, this requirement is counterintuitive because the data is generated online during training. In order for these models to work with Lightning, custom Datasets and runners need to be created to generate the batch and wrap it in a DataLoader.
Pitch
Add additional hooks for specifying how a batch is generated for train, val, and test. Going forward, I think there are two options to implement this feature:
1. Check whether the user has implemented the train_batch function and call it directly in the training loop instead of the dataloader. This will probably require a lot of code churn to handle the various checks for the dataloader in the main Trainer.
2. Create a 'dummy' dataloader that wraps the train_batch function provided by the LightningModule (a rough sketch follows below). This is probably the simpler implementation, as the underlying logic in the Trainer that handles the dataloader shouldn't need to change.
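A rough sketch of what option 2 might look like; `_TrainBatchDataset` and `build_dummy_dataloader` are hypothetical names, and `train_batch` is assumed to be a generator that yields full batches (as in the earlier sketch):

```python
from torch.utils.data import DataLoader, IterableDataset


class _TrainBatchDataset(IterableDataset):
    """Hypothetical wrapper exposing a model's train_batch hook as a dataset iterator."""

    def __init__(self, model):
        self.model = model

    def __iter__(self):
        # train_batch is assumed to be a generator yielding full batches
        return self.model.train_batch()


def build_dummy_dataloader(model):
    # Each item yielded by the hook is already a complete batch, so automatic
    # batching is disabled; the Trainer can consume this like any other dataloader.
    return DataLoader(_TrainBatchDataset(model), batch_size=None)
```

Because automatic batching is disabled, whatever the hook yields would be handed to training_step unchanged.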
Alternatives
The alternative would be to have the user create a custom DataLoader/Dataset that generates the batch on the fly. This gives the same end result, but is arguably a messier solution and requires more work and effort from the user.