RLHF with PPO #1005
…ajectory generation, tests for advantage and return estimation
…ng from checkpoints
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1005
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure
As of commit 4e6be43 with merge base 5019074: NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…, adding support for saving value head checkpoints
…h padded inputs, added tests for new generation
…rejection sampling and reward model masking, moved utils
…tion sampling masking, tests for get_causal_mask
…l files, added tests for ppo collation
…loss coefficient and refactored kl controllers
This is shaping up nicely - I very much like that we de-scoped this to focus on full finetune first with lora and qlora as follow-ups. A couple high level things to note:
- I'm concerned about the configs becoming too bloated and would like to discuss how to minimize storing lots of logic there.
- What is the largest model you can fit on an 80GB A100? I see you include configs for both 7B and 1B?
Thanks for another review.
I feel this. One thing that stuck out to me when writing this - we currently need 4 checkpointers, two of which are solely used to point to the original weights for the policy and reward models, respectively. They're necessary because you need the reference to the original weights when resuming training, and the choice to me at the time was managing this state in the config vs the checkpoints. The model definitions are also taking up a lot of space, but that's largely because I didn't see another obvious way to configure a 1B Llama2. The model definition in the Mistral config is annoying because that specific reward model uses a different vocab size. Please let me know if I can make this cleaner! There's also ~30 lines for hyperparameters in the config. Hopefully this won't be overwhelming to the user once we include a cookbook. I could remove 5 or so of these from the config and set as defaults in the recipe. EDIT: I could also liberally use
The training run in my above comment trained Mistral 7B on an 80GB A100.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1005      +/-   ##
==========================================
- Coverage   69.32%   67.96%   -1.37%
==========================================
  Files         233      246      +13
  Lines       10593    11434     +841
==========================================
+ Hits         7344     7771     +427
- Misses       3249     3663     +414
☔ View full report in Codecov by Sentry.
bump
bump @joecummings @ebsmothers. My dearest reviewers, sorry to ping you when you're busy. In my defense, @kartikayk did tell me to. What can I do to help move this along? I'm more than happy to help reduce the review overhead if I can.
Outstanding discussions/tasks:
tests/test_utils.py
Outdated
@@ -29,6 +29,7 @@
"llama2_tune": "/tmp/test-artifacts/small-ckpt-tune-03082024.pt",
"llama2_meta": "/tmp/test-artifacts/small-ckpt-meta-03082024.pt",
"llama2_hf": "/tmp/test-artifacts/small-ckpt-hf-03082024.pt",
"llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-12072024.pt", # TODO (SalmanMohammadi)
Sorry for being an American chauvinist but I changed the filename to small-ckpt-hf-reward-07122024.pt (really just want to make it consistent with the format of the other ones). Also I think you will need to update cache_artifacts.sh correspondingly.
Oh how the Empire has fallen from grace.
General comment on the checklist you left earlier: all the points look good to me, let's just file tasks for some of the more important todos that don't have them already. Also, leaving some miscellaneous remarks here in response to several of your previous comments:
Looking at the figures seems this is necessary even for A100? Since you are still pretty close to 80GB allocated memory. I'm also curious whether the overall training speed is decent as these configs can slow things down quite a bit.
Just want to confirm: will we actually be able to run this recipe without batched generation support?
OK a bunch more comments but after that there are no major concerns. Home stretch here -- thanks again for your immense patience on this one
(seq_lens > 0) & (seq_lens < self._max_generated_tokens - 1),
seq_lens + 1,
seq_lens,
Sorry not sure I fully follow what the purpose of this is
...
...
...
Dare I say..... excalidraw?
In all seriousness, I can send you an equally confusing diagram from my notes on Discord. This took me a while to wrap my head around, and longer to explain coherently (disclaimer, this could just all be wrong, since my only reference is a single line from a Learning to Summarize implementation), so, thanks for the nerd snipe.
The TL;DR - the value function is estimating the return for the whole sequence at each step, which is the reward model score for the (query, truncated response), plus the KL per-token penalty. We want to use this for the advantage estimation, and the advantage for the last action taken (the last valid non-padding token generated by the model), is:
So, we need the value estimate (return) for the sequence up to now, plus one step ahead. For the last token, this means we need to extend the padding mask out by one for the values: instead of masking everything after the last non-padding token, we mask everything one value after the last non-padding token.
These three lines do this, but just add some logic to say if we're already at the end of the sequence then we don't need to extend the mask.
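To make the mask extension concrete, here is a minimal pure-Python sketch of what those three lines compute. It assumes the usual one-step form of the advantage at the last action T, delta_T = r_T + gamma * V(s_{T+1}) - V(s_T) (standard GAE notation, not quoted from this PR), which is why one extra value position has to stay unmasked. The function names, the interpretation of seq_lens as the index of the last valid generated token, and the example batch are all assumptions for illustration; the recipe itself operates on tensors.

```python
def extend_value_seq_lens(seq_lens, max_generated_tokens):
    """List-based mirror of the three-line condition discussed above.

    Assumption: seq_lens[i] is the index of the last valid (non-padding)
    generated token of response i. The entry is pushed out by one so the
    value at step T+1 survives masking, unless the response already runs
    to the end of the generation window (or the entry is 0).
    """
    return [
        n + 1 if 0 < n < max_generated_tokens - 1 else n
        for n in seq_lens
    ]


def value_padding_mask(value_seq_lens, max_generated_tokens):
    """True marks value positions to mask: everything after value_seq_lens[i]."""
    return [
        [pos > n for pos in range(max_generated_tokens)]
        for n in value_seq_lens
    ]


# Hypothetical batch: three responses, at most 6 generated tokens each.
seq_lens = [2, 4, 5]
extended = extend_value_seq_lens(seq_lens, max_generated_tokens=6)
mask = value_padding_mask(extended, max_generated_tokens=6)
```

In this sketch the first two responses get one extra unmasked value position, while the third is already at the end of the window and is left unchanged.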
Thanks, I think this makes sense (though I reserve the right to be confused again later on)
O' what glorious reviews. Thank you. I've addressed your comments.
For 7B, since we're fitting 4x7B models in, it took a little wrangling to fit it all. The run I posted took ~3 hours. I haven't found any tests to benchmark against on comparable hardware to estimate appropriate speed/memory usage. DeepSpeed's RLHF states:
Not 100% clear on whether the upper parameter bound calculation requires the specific config they listed, but going by that, their method offers ~13GB for the sum of the actor and critic models, max, on an 80GB A100. TRL trained Pythia 6.9B on their PPOV2 trainer with 8xH100.
The generation utils I include in this PR do provide batched generation support.
Thank you for your immense patience on this one. I left a couple of other follow-up comments, but none of them are blocking us from landing this. 🚀
Context
What is the purpose of this PR? Is it to
#812
Background reading:
The N Implementation Details of RLHF with PPO, Huang et al.
The N+ Implementation Details of RLHF with PPO: A case study on TL;DR Summarization
The original RLHF paper - Fine-Tuning Language Models from Human Preferences, Ziegler et al.
Anthropic's RLHF paper - Training a Helpful and Harmless Assistant with Learning from Human Feedback
Training language models to follow instructions with human feedback, Ouyang et al.
Shameless plug, but I would have genuinely found this post helpful when I started out with PPO, even for skimming through some of the references - The theory of Proximal Policy Optimization implementations
Changelog:
- torchtune.models.mistral: TransformerDecoder to TransformerDecoderWithHiddenLayer
- TransformerLM, which wraps an output projection around TransformerDecoderWithHiddenLayer
- TransformerLM TransformerLMWithValueHead, with two linear projections: one for the LM head and one for the value head.
- TransformerDecoderWithHiddenLayer added MistralClassifier.
- test_ppo_loss tests for correct behaviour based on expected relative value and policy loss for different inputs.
- utils.ppo_utils for various ppo utils, and tests for all files including:
  - _generation.py
    - custom_generate_next_token functions for generating with value head models, and for generating with masks and input positions.
    - get_causal_masks for creating masks of shape [bsz, seq_len, seq_len] which correctly mask leading padding tokens, suitable for use with scaled_dot_product_attention.
    - generate function which generates sequences using above functionality.
  - collate.py
  - rewards.py
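As an illustration of the mask shape described in the changelog, here is a rough pure-Python sketch of a causal mask that also hides leading padding tokens. The function name causal_mask_with_left_padding is hypothetical and this is not the PR's get_causal_masks; it only follows the boolean convention of scaled_dot_product_attention, where True means the position takes part in attention. Forcing the diagonal to True is an assumption added here so that padded query rows are never fully masked, which would otherwise produce NaNs in the softmax.

```python
def causal_mask_with_left_padding(padding_mask):
    """Build a [bsz, seq_len, seq_len] boolean attention mask.

    padding_mask: list of lists of bool, True where the token is real
    (not padding). In the returned mask, entry [b][i][j] is True when
    query position i may attend to key position j.
    """
    masks = []
    for row in padding_mask:
        seq_len = len(row)
        # Causal constraint (j <= i) combined with "key j is not padding".
        mask = [
            [j <= i and row[j] for j in range(seq_len)]
            for i in range(seq_len)
        ]
        # Assumption: let every position attend to itself so that rows
        # belonging to padding tokens are not entirely False.
        for i in range(seq_len):
            mask[i][i] = True
        masks.append(mask)
    return masks


# Hypothetical example: one sequence with two leading padding tokens.
example = causal_mask_with_left_padding([[False, False, True, True]])
```

With real tensors, the same pattern would be passed as the boolean attn_mask argument; whether the recipe handles padded query rows the same way is not something this sketch claims.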
TODO:
TODO (@SalmanMohammadi)
Adding this to open up discussion and get some feedback (@kartikayk) while I train models and verify correctness. Maybe a good place to start would be the TransformerDecoder refactor?
closes #812
pre-commit install
pytest tests
pytest tests -m integration_test