
Fix gradient scaling to account for world_size normalization #2172

Merged

Conversation

mirceamironenco
Contributor

@mirceamironenco mirceamironenco commented Dec 18, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • FSDP/FSDP2/DDP normalize the gradients by world_size when performing the all_reduce. For a sequence-processing task where the desired loss is normalized by the total number of non-padded, non-ignored tokens, this world_size normalization has to be undone. For example, if world_size = 2 and we have two sets A, B of gradient-producing tokens, the total loss we want is (loss(A) + loss(B)) / (|A| + |B|), where the normalization factor 1 / (|A| + |B|) is currently handled by scale_grads:
    training.scale_grads(self._model, 1 / num_tokens)

If A and B are processed on separate data-parallel workers, the current gradients correspond to loss(A) / 2 + loss(B) / 2, so with the normalization applied as before the effective loss becomes (loss(A) + loss(B)) / (2 * (|A| + |B|)). This PR multiplies the scaling factor by world_size so that the extra 1 / world_size cancels out (a minimal sketch follows below).
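A minimal sketch of the corrected scaling, assuming the existing training.scale_grads utility and an initialized process group (the recipe may already track world_size; dist.get_world_size() is used here only to keep the snippet self-contained):

import torch.distributed as dist

# FSDP/DDP has already divided the summed gradients by world_size, so folding
# world_size into the token normalization leaves the grads scaled by
# 1 / num_tokens overall.
world_size = dist.get_world_size()
training.scale_grads(self._model, world_size / num_tokens)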

I haven't seen very large differences wrt loss curves in my preliminary experiments after this change:

[Screenshots: W&B loss curves for the two runs described below]

In these plots, world_size denotes the run where the gradient scaling factor is world_size / num_tokens; the other run uses 1 / num_tokens. The commands to replicate the plots are:

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_2/3B_full metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=llama3.23b_fix metric_logger.name=world_size dataset.packed=True tokenizer.max_seq_len=512 compile=True

tune run --nproc_per_node 2 full_finetune_distributed.py --config configs/llama3_2/3B_full metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=llama3.23b_fix_noprompt metric_logger.name=world_size dataset.packed=True dataset.train_on_input=False tokenizer.max_seq_len=512 compile=True

Someone with more compute budget can probably get a better idea of the effect for larger models.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Dec 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2172

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a6dc03a with merge base 27fd3a1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 18, 2024
@ebsmothers
Contributor

Thanks @mirceamironenco for finding this bug and for making the fix! Apologies for the delay in getting back to it; I wanted to put together a minimal repro to validate this myself (I trust the code pointers, but I like seeing numerical parity), so I wrote the following script(s) to convince myself. Can confirm that on identical toy models with identical data we see (grad on N devices) == (grad on single device) / N. Let me run some more experiments to see to what extent this will affect loss curves at larger world sizes. If there is an impact, we should give people an FYI before landing. Will get back to you soon once I run the experiments!
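A hypothetical stand-in for the parity check described here (not the author's script), using plain DDP on CPU rather than FSDP2; launch with torchrun --nproc_per_node 2 repro.py:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("gloo")  # CPU-friendly backend for a toy check
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Same seed on every rank -> identical weights and identical data everywhere.
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 1, bias=False)
    ref = torch.nn.Linear(4, 1, bias=False)
    ref.load_state_dict(model.state_dict())
    ddp_model = DDP(model)

    # Full batch for the single-process reference; each rank takes a 2-row shard.
    x = torch.randn(2 * world_size, 4)
    y = torch.randn(2 * world_size, 1)
    shard = slice(rank * 2, (rank + 1) * 2)

    # Sum-reduction losses, so any averaging comes from the data-parallel wrapper.
    ((ddp_model(x[shard]) - y[shard]) ** 2).sum().backward()
    ((ref(x) - y) ** 2).sum().backward()

    if rank == 0:
        # Expect: (grad on N devices) == (grad on single device) / N
        print("ddp grad * world_size:", (ddp_model.module.weight.grad * world_size).tolist())
        print("single-process grad:  ", ref.weight.grad.tolist())

    dist.destroy_process_group()

if __name__ == "__main__":
    main()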

@EugenHotaj
Contributor

@ebsmothers any updates on this? We've also seen this in our multi-node runs -- our grad norms are significantly smaller than what we see from other frameworks (e.g. NeMo).

@ebsmothers
Contributor

Hey @EugenHotaj thanks for the bump -- yes, we plan to land this soon. Actually the main reason for being slow on this PR (besides the holidays and PSC) is that we wanna be careful about breaking people who have e.g. their LR tuned to this setting. Ultimately I think we need to just rip the bandaid off and make the fix, then put comms here and in our Discord. Let me try to review and land later today

@mirceamironenco mirceamironenco force-pushed the fix-world-size-normalization branch from 34906b2 to a6dc03a Compare January 7, 2025 21:56
@ebsmothers
Contributor

Thanks for your patience @mirceamironenco. Just ran some quick experiments on my end on a single node with 8 GPUs. Attaching some plots below, WandB project is here. There are three runs: one on main, one on this PR, and one on this PR with learning rate scaled by 1/8.

[Screenshots: loss and grad norm curves for the three runs (main, this PR, this PR with LR scaled by 1/8)]

Unsurprisingly, it's similar to what @EugenHotaj mentioned -- the grad norm is off, almost exactly by a factor of 8. At least in my case the loss curves are pretty much identical too; not sure if there's a noticeable difference multi-node.

@ebsmothers ebsmothers self-requested a review January 7, 2025 23:25
@EugenHotaj
Contributor

At least for my case the loss curves are pretty much identical too

@ebsmothers I've noticed this as well on my runs and found it a bit surprising. Is it expected that the losses would be identical? The gradients point in the same direction but I would have thought we'd see some divergence after taking a few hundred gradient steps. I guess gradient clipping / LR accounts for a lot of this?

@codecov-commenter

codecov-commenter commented Jan 8, 2025

Codecov Report

Attention: Patch coverage is 0% with 8 lines in your changes missing coverage. Please review.

Project coverage is 23.95%. Comparing base (213f386) to head (a6dc03a).
Report is 1 commit behind head on main.

Files with missing lines                              Patch %   Missing lines
recipes/full_finetune_distributed.py                    0.00%   2
recipes/qat_distributed.py                              0.00%   2
recipes/knowledge_distillation_distributed.py           0.00%   1
recipes/lora_finetune_distributed.py                    0.00%   1
recipes/lora_finetune_distributed_multi_dataset.py      0.00%   1
recipes/qat_lora_finetune_distributed.py                0.00%   1

❗ There is a different number of reports uploaded between BASE (213f386) and HEAD (a6dc03a): HEAD has 6 fewer uploads than BASE (3 vs. 9).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2172       +/-   ##
===========================================
- Coverage   65.41%   23.95%   -41.47%     
===========================================
  Files         344      352        +8     
  Lines       20658    20847      +189     
===========================================
- Hits        13514     4993     -8521     
- Misses       7144    15854     +8710     

☔ View full report in Codecov by Sentry.

@ebsmothers
Contributor

Is it expected that the losses would be identical? The gradients point in the same direction but I would have thought we'd see some divergence after taking a few hundred gradient steps. I guess gradient clipping / LR accounts for a lot of this?

@EugenHotaj yeah this gave me a bit of a scare, especially considering that we don't enable gradient clipping by default. Actually I believe that the behavior has to do with the optimizer: I used SGD instead of AdamW and manually hacked in a really high value for the grad scaler just to make sure nothing was broken. In that case the difference is very noticeable (see below). I didn't think that momentum would result in consistent loss curves when scaling grads up and down, but maybe I just need to refresh my memory on Adam a bit.

[Screenshot: loss curves for SGD with the gradient scaler hacked to a large value]
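As a hypothetical toy check of the same effect (not from this thread, using the stock torch.optim.SGD and AdamW): hand each optimizer the same constant gradient, rescale it by a factor c, and measure how far the parameter moves.

import torch

def total_movement(opt_cls, grad_scale, **kwargs):
    p = torch.nn.Parameter(torch.ones(1))
    opt = opt_cls([p], **kwargs)
    for _ in range(10):
        opt.zero_grad()
        p.grad = torch.full_like(p, 0.5) * grad_scale  # hand-set, rescaled gradient
        opt.step()
    return 1.0 - p.item()  # how far p moved from its initial value of 1.0

for c in (1.0, 8.0):
    sgd = total_movement(torch.optim.SGD, c, lr=0.1)
    adamw = total_movement(torch.optim.AdamW, c, lr=0.1, weight_decay=0.0)
    print(f"grad scale {c}: SGD moved {sgd:.3f}, AdamW moved {adamw:.3f}")

The SGD movement scales linearly with the gradient scale, while the AdamW movement is essentially unchanged, consistent with the loss curves above.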

@EugenHotaj
Contributor

I didn't think that momentum would result in consistent loss curves when scaling grads up and down, but maybe I just need to refresh my memory on Adam a bit.

@ebsmothers any chance we also need to do the same to adam momentum params when using FSDP? Pretty surprising to me as well that Adam would lead to identical learning curves

@mirceamironenco
Contributor Author

mirceamironenco commented Jan 8, 2025

@EugenHotaj yeah this gave me a bit of a scare, especially considering that we don't enable gradient clipping by default. Actually I believe that the behavior has to do with the optimizer: I used SGD instead of AdamW and manually hacked in a really high value for the grad scaler just to make sure nothing was broken. In that case the difference is very noticeable (see below). I didn't think that momentum would result in consistent loss curves when scaling grads up and down, but maybe I just need to refresh my memory on Adam a bit.
[Screenshot: loss curves for SGD with the gradient scaler hacked to a large value]

Just to make sure I understand, if you only hack the grad scaler to be very large but keep AdamW, the loss curves are still basically identical? (IIUC you did both in this comparison?)

Maybe the loss curves being very similar is not so strange, since the denominator will have a very large number of tokens compared to world_size, but here are some other ideas to battle-test this more (I can implement these in a separate branch just for a comparison if you want):

  1. Hardcoding the reduce_op when wrapping with fully_shard as mentioned in an earlier version of the PR:
# Import paths below are assumptions based on the FSDP2 APIs at the time of this PR.
from torch.distributed import ReduceOp
from torch.distributed._composable.fsdp import fully_shard

# Must be done for each sharded module.
module = fully_shard(
    module,
    mesh=mesh,
    reshard_after_forward=reshard_after_forward,
    shard_placement_fn=shard_placement_fn,
    mp_policy=mp_policy,
    offload_policy=offload_policy,
)
# Change the reduce op manually (private attribute) so gradients are
# summed instead of averaged across ranks.
fsdp_param_group = fully_shard.state(module)._fsdp_param_group
fsdp_param_group.reduce_scatter_reduce_op = ReduceOp.SUM

This happens before the optimizer is initialized, in case anything relevant is happening there.

  2. Implementing a DDP variant (you could also try it with reshard_after_fwd=False, but this would only be equivalent to ZeRO-2).

  3. Same experiment you did but with very high/very low learning rate.

Potentially getting some feedback from the FSDP2 authors just as a sanity check could be useful.

@ebsmothers
Contributor

Just to make sure I understand, if you only hack the grad scaler to be very large but keep AdamW, the loss curves are still basically identical? (IIUC you did both in this comparison?)

@mirceamironenco Yeah this is correct. Re your suggestions, (3) was the first one that came to my mind (also conveniently the easiest 😃) so I gave that a try on our distributed LoRA recipe. The below plot is the result of running AdamW with a much higher LR of 0.01, you can see that the two loss curves diverge (also unsurprisingly the loss blows up):

[Screenshot: loss curves with AdamW at LR 0.01, diverging between the two gradient scalings]

But the point is that scaling the gradients can result in different loss curves with AdamW; it just doesn't really show up under our baseline configs (which I suppose is a good thing with respect to the impact of this whole world-size-scaling bug).

Can also tag in our resident optimizer expert @janeyx99 in case she has any thoughts. TLDR for Jane is that we manually scale grads using this utility just before optimizer step, but surprisingly even scaling by a pretty large amount doesn't really mess with our loss curves when using AdamW (while for SGD there is a noticeable impact).

@janeyx99
Contributor

janeyx99 commented Jan 8, 2025

Not sure how helpful this is, but yes, I'm not surprised the gradient changes didn't affect AdamW as much as it did SGD. The SGD update is very gradient-dependent:
[image: SGD update rule]

whereas the Adam(W) update is scaled by momentum over sqrt(variance), which is like scaling by g/sqrt(g^2) with all the minutiae stripped away:
[image: Adam(W) update rule]
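For reference, the standard forms of the two updates being contrasted, with weight decay omitted:

$$\theta_{t+1} = \theta_t - \gamma \, g_t \qquad \text{(SGD, no momentum)}$$

$$\theta_{t+1} = \theta_t - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \qquad \text{(Adam/AdamW)}$$

Rescaling $g_t$ by a constant $c$ scales the SGD step by $c$, whereas in Adam(W) it scales $\hat{m}_t$ by $c$ and $\sqrt{\hat{v}_t}$ by $c$, so the step is essentially unchanged.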

@ebsmothers
Contributor

ebsmothers commented Jan 9, 2025

Thanks @janeyx99! This is very helpful. Also I clearly should've just dug up the Adam paper. Direct quote:

Assuming $\epsilon = 0$, the effective step taken in parameter space at timestep $t$ is $\Delta_t = \alpha \cdot \hat{m}_t / \sqrt{\hat{v}_t}$.
...
The effective stepsize $\Delta_t$ is also invariant to the scale of the gradients; rescaling the gradients with factor $c$ will scale $\hat{m}_t$ with a factor $c$ and $\hat{v}_t$ with a factor $c^2$, which will cancel out: $(c \cdot \hat{m}_t) / (\sqrt{c^2 \cdot \hat{v}_t}) = \hat{m}_t / \sqrt{\hat{v}_t}$.

(Please excuse my sloppy LaTeX, I swear I was good at this once..) Also thanks @mirceamironenco for mentioning this over chat and forcing me to dig it up. So actually I think we are good here -- in fact I'm no longer even worried about breaking BC with this change after actually having done my homework.

Contributor

@ebsmothers ebsmothers left a comment


Thanks for finding and fixing the bug @mirceamironenco! And thanks for your patience while we sorted out the whole Adam grad scaling thing in review. Based on our discussion, I think this is good to go.

@ebsmothers ebsmothers merged commit e420bc0 into pytorch:main Jan 9, 2025
17 checks passed
@mirceamironenco mirceamironenco deleted the fix-world-size-normalization branch January 10, 2025 08:11