
Set gloo process group for FSDP with CPU offload #2108

Merged · 4 commits merged into pytorch:main on Dec 4, 2024

Conversation

ebsmothers (Contributor) commented on Dec 3, 2024

Addresses #1977.

As discussed in the issue (see this comment), FSDP's implementation of gradient clipping uses _NormPartial, which requires communication primitives (specifically all_reduce). This means that when running with CPU offloading, we need to initialize the gloo process group in order to compute the grad norm for DTensors on CPU. For simplicity, this PR enables gloo whenever CPU offloading is enabled, regardless of whether gradient clipping is used.
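In practice the fix amounts to registering two backends on the default process group, which PyTorch's init_process_group supports via a per-device backend string. A minimal sketch (not the recipe code itself):

```python
import torch.distributed as dist

# One process group, two backends: NCCL serves collectives on CUDA tensors,
# gloo serves collectives on CPU tensors, e.g. the all_reduce that
# _NormPartial issues when reducing the grad norm of CPU-offloaded DTensors.
dist.init_process_group(backend="cuda:nccl,cpu:gloo")
```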

Test plan:

Added a test case for gradient clipping + CPU offload to test_full_finetune_distributed.py.

pytest -m integration_test tests/recipes/test_full_finetune_distributed.py -k 'test_loss'
...
========= 3 passed, 1 deselected in 43.54s ========
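Independently of the recipe test, the failure mode can be reproduced with FSDP2 directly. A hypothetical minimal repro (this is not the torchtune test case; it assumes a recent PyTorch, ~2.6+, where fully_shard and CPUOffloadPolicy are importable from torch.distributed.fsdp):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard


def main():
    # Register NCCL for CUDA tensors and gloo for CPU tensors.
    dist.init_process_group(backend="cuda:nccl,cpu:gloo")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # With CPUOffloadPolicy, sharded params and grads live on CPU as DTensors.
    model = torch.nn.Linear(16, 16)
    fully_shard(model, offload_policy=CPUOffloadPolicy())

    model(torch.randn(4, 16, device="cuda")).sum().backward()
    # Computing the grad norm over CPU-resident DTensor shards is what issues
    # the CPU all_reduce; without gloo registered, this is where things break.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Run with e.g. torchrun --nproc_per_node=2 repro.py; if the process group were initialized with only "nccl", the clip_grad_norm_ call would have no backend to serve the CPU all_reduce.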

pytorch-bot (bot) commented on Dec 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2108

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e4f00c4 with merge base 32e265d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Dec 3, 2024
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()
process_group = "cuda:nccl,cpu:gloo"
Contributor commented:

n00b question: why not just make this the default every time, instead of "gloo" if cfg.device == "cpu" else "nccl"?

Contributor replied:

[screenshot attached]

ebsmothers (Contributor, Author) replied:

I actually wasn't sure myself why we do this. Well, now I know. So yes, I think we can take your suggestion here.
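For concreteness, taking that suggestion would look something like the following (a sketch of the suggested change, not the merged diff):

```python
# Always register both backends, so CPU collectives work whether or not
# fsdp_cpu_offload is set; keep the thread tuning gated on CPU offload.
init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
    # Utilize all available CPU cores for intra-op parallelism. This provides
    # ~2x speed up when benchmarking fused AdamW on CPU.
    training.set_torch_num_threads()
```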

felipemello1 (Contributor) approved these changes:

This seems to be low risk, since it was tested in the DCP PR already. Approving to unblock.

@ebsmothers merged commit 5eb04cd into pytorch:main on Dec 4, 2024 · 17 checks passed