-
Notifications
You must be signed in to change notification settings - Fork 482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PPO Performance Improvements #2066
base: main
Are you sure you want to change the base?
PPO Performance Improvements #2066
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2066
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 92927c4 with merge base f2bd4bc (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Geez! >3x improvement is no joke. I don't think i will have time to review it this week. But I am very curious to see the changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
YOU SHALL NOT PASS
@@ -94,15 +94,15 @@ def generate_next_token( | |||
- tokens (torch.Tensor): tensor with the generated tokens, | |||
with shape [bsz x 1]. | |||
- logits (torch.Tensor): tensor with the logits associated with the generated tokens, | |||
with shape [bsz x seq_length x vocab_size]. | |||
with shape [bsz x 1 x vocab_size]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, this is a BC breaking change for a public API which means we need to deprecate accordingly. Can you make this a flag that is enabled for the PPO use case, then add a deprecation warning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm afraid this is going to be challenging to do without introducing graph breaks during compilation. I generally agree with you, though in this case I'm not 100% sure who would be using logits returned from this function outside of PPO
@@ -355,8 +355,8 @@ def generate( | |||
# if incremental decoding is enabled, we can use the current position | |||
# otherwise, we take the whole sequence up to the current position | |||
if incremental_decoding: | |||
curr_input_pos = input_pos[:, curr_pos] | |||
curr_masks = masks[:, curr_pos, None, :] | |||
curr_input_pos = input_pos[:, curr_pos].contiguous() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is making a copy of the tensor? So is it slower when not compiling?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found this was faster in compile as it avoids recompiles on the mask and input_pos strides.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, I definitely believe that, but how does it compare when not compiling?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's still pretty minimal but YMMV depending on bsz and sequence length. When I last profiled it:
The little black sliver is the .contiguous
call, which takes around 50us every step (some napkin math means this overhead is roughly 50us * max_generated_tokens - 1
) compared to the 4.562s for generating the entire sequence, so a minimal portion of time.
@@ -189,7 +189,7 @@ def get_position_ids_from_padding_mask( | |||
return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int) | |||
|
|||
|
|||
@torch.inference_mode() | |||
@torch.no_grad() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why? Truly don't fully understand the difference here lol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inference_mode
changes the attributes of the tensors which will trigger unnecessary recompiles without really being that useful
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to define "without really being that useful"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll need to double check where I found this - I think it was in a PyTorch dev podcast, but when it was released PyTorch folks mentioned up to ~5% improvement gains on deployed models internally at FB. HF PRs for including inference_mode
in generation didn't really find speedups to the same degree, so they still use no_grad
for generation. To expand on my point above, under compile inference_mode
tensors have different metadata properties and we trigger recompiles when guards are created on these properties which results in increased warmup time.
) | ||
# note that if mask_sum == 1, then there is a division by zero issue | ||
# to avoid it you just need to use a larger minibatch_size | ||
mask_sum = mask.sum() + 1e-8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At this point maybe we make the added value configurable rather than 1e-8?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure when someone would want to consider configuring this value
Is there anything else blocking this from landing? @joecummings |
Closes #1425
This PR provides various performance improvements to our PPO single device recipe.
*The models were trained over approx. 37M tokens (~65k samples w/
Due to the non-determinism of the training process curves may look slightly different.max_seq_len=512
) on a single A100 GPU.Changelog:
generation.generate
now only returns logits over the generated tokens rather than the whole sequence - significantly reduces peak memory usage. Tests have been updated.parents=True
tooutput_dir.mkdir
in our checkpointers. We use nested checkpoint folders for PPO e..goutput_dir/policy/
,output_dir/value/
.training.compile_model
- this results in ~10 recompile warnings, which means we need to increase the compile cache size limit - I've addedtorch._dynamo.config.cache_size_limit = 16
at the top of the recipe.I landed on option 2 - it's similar to how we integrate compile with the rest of our recipes, and it eliminates the small warmup overhead. To fully realize compile speedups it's recommended to do a small warm-run of the recipe with compile enabled.