
Update lr_schedules.py #4563

Merged: 14 commits merged into microsoft:master on Nov 10, 2023

Conversation

CoinCheung
Contributor

add cosine annealing scheduler

This scheduler is widely used in image classification tasks, and many LLMs (e.g., LLaMA) also use it.
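For readers unfamiliar with the schedule, here is a minimal sketch of the warmup-plus-cosine-annealing shape being added. It is a standalone helper, not the PR's actual class: warmup_min_ratio and warmup_num_steps follow the names used later in this thread, while total_num_steps and cos_min_ratio are illustrative names.

```python
import math

def warmup_cosine_multiplier(step, total_num_steps, warmup_num_steps=100,
                             warmup_min_ratio=0.01, cos_min_ratio=0.0):
    """Return a multiplier applied to the optimizer's base lr.

    Phase 1: linear warmup from warmup_min_ratio up to 1.0 over warmup_num_steps.
    Phase 2: cosine annealing from 1.0 down to cos_min_ratio over the remaining steps.
    """
    if step < warmup_num_steps:
        # Warm up on the ratio, not on absolute lr values.
        return warmup_min_ratio + (1.0 - warmup_min_ratio) * step / warmup_num_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_num_steps) / max(1, total_num_steps - warmup_num_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return cos_min_ratio + (1.0 - cos_min_ratio) * cosine
```

With a base lr of 1e-3, the lr at any step is simply 1e-3 * warmup_cosine_multiplier(step, ...).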

@CoinCheung
Contributor Author

@microsoft-github-policy-service agree

@tjruwase
Contributor

@CoinCheung, thanks for the PR. A few items to address:

- To fix formatting issues, use this guide.
- Please add a unit test: example.
- Inspect the failing CI tests.
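As an illustration only, here is the kind of endpoint check such a unit test might cover, written against the standalone helper sketched above rather than the real scheduler class (an actual DeepSpeed test would construct the scheduler with an optimizer instead):

```python
import math

def test_warmup_cosine_endpoints():
    total, warmup = 1000, 100
    # Start of warmup: multiplier equals warmup_min_ratio.
    assert math.isclose(warmup_cosine_multiplier(0, total, warmup, warmup_min_ratio=0.1), 0.1)
    # End of warmup: multiplier reaches 1.0, i.e. the optimizer's peak lr.
    assert math.isclose(warmup_cosine_multiplier(warmup, total, warmup, warmup_min_ratio=0.1), 1.0)
    # End of training: multiplier has annealed down to cos_min_ratio.
    assert math.isclose(warmup_cosine_multiplier(total, total, warmup, cos_min_ratio=0.1), 0.1)
```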

@CoinCheung CoinCheung requested a review from mrwyattii as a code owner October 26, 2023 03:41
@CoinCheung
Contributor Author

@tjruwase @wjessup @dfyz @manuelciosici I have no experience with PRs for DeepSpeed. What is the status of this now? Is there any further work that needs to be done on my end?

@tjruwase
Contributor

@CoinCheung, thanks for making the changes. We will review and merge once the CI passes.

@CoinCheung
Contributor Author

Hi @tjruwase,

I have made some fixes. Would you please launch the CI tests one more time?

@CoinCheung
Contributor Author

@tjruwase Would you please launch CI one more time?

@CoinCheung
Contributor Author

Hi @jeffra @mrwyattii, I think the problem is not with my change; the failure is an inference error, while my change only touches the training learning rate scheduler. Can this fix be merged, or is there anything else I need to commit?
[screenshot of the failing CI test]

@tjruwase
Contributor

@CoinCheung, sorry for the delay. It seems the issue is with our CI system. Please bear with us while we resolve the problem.

@tjruwase tjruwase added this pull request to the merge queue Oct 31, 2023
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 31, 2023
@CoinCheung
Contributor Author

Hi @tjruwase, what is the status of this thread?
[screenshot of the CI status]

@tjruwase
Contributor

tjruwase commented Nov 7, 2023

@CoinCheung, I have restarted CI. Let's see how it goes.

@CoinCheung
Contributor Author

Hi @tjruwase,

Is this failure associated with my changes?
[screenshot of a CI test failure]

@tjruwase
Contributor

tjruwase commented Nov 8, 2023

@CoinCheung, no, I don't think it is related to your changes.

@CoinCheung CoinCheung requested a review from tjruwase November 10, 2023 01:11
@tjruwase tjruwase added this pull request to the merge queue Nov 10, 2023
Merged via the queue into microsoft:master with commit 4388a60 Nov 10, 2023
15 checks passed
@kmn1024

kmn1024 commented Nov 16, 2023

kmn1024 added a commit to kmn1024/axolotl that referenced this pull request Nov 16, 2023
@tjruwase
Contributor

> Should WarmupCosineLR inherit from WarmupLR? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/lr_schedules.py#L774

Yes, you are correct. It should.

@CoinCheung, are you able to refactor your changes? Thanks!

@CoinCheung
Contributor Author

CoinCheung commented Nov 17, 2023

Hi @tjruwase @kmn1024, I do not think WarmupCosineLR can inherit from WarmupLR in this case.

They use different methods to determine the learning rates. For WarmupCosineLR, I use ratios of the original lr values, which I think is the more principled approach, while WarmupLR uses specific lr values.

For example, with WarmupLR, setting warmup_min_lr=1e-5 and warmup_num_steps=100 makes the scheduler ramp the lr from 1e-5 to the max lr within 100 steps.
With WarmupCosineLR, setting warmup_min_ratio=0.1 and warmup_num_steps=100, and assuming lr=1e-3 when the optimizer was defined, the scheduler ramps the lr from 0.1 * lr = 1e-4 to the max lr within 100 steps. We never pass specific lr values to the scheduler.
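To make the contrast concrete, a hedged sketch of the two parameterizations as scheduler config dicts; warmup_max_lr and total_num_steps are illustrative names, and only warmup_min_lr, warmup_min_ratio, and warmup_num_steps are taken from the discussion above:

```python
# WarmupLR: the scheduler itself carries absolute lr values, duplicating
# information that already lives in the optimizer.
warmup_lr_params = {
    "warmup_min_lr": 1e-5,     # starting lr, absolute value
    "warmup_max_lr": 1e-3,     # peak lr, absolute value (illustrative name)
    "warmup_num_steps": 100,
}

# WarmupCosineLR: the scheduler only describes the shape of the curve; the
# peak lr stays with the optimizer (1e-3 here), so warmup starts at 0.1 * 1e-3 = 1e-4.
warmup_cosine_params = {
    "warmup_min_ratio": 0.1,
    "warmup_num_steps": 100,
    "total_num_steps": 1000,   # illustrative name for the decay horizon
}
```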

The reason I feel using ratios is better: we do not need to set specific lr values in both the optimizer and the scheduler. When we define an optimizer, we need to choose the learning rate. When we define a scheduler, all we need to do is determine the shape of the learning rate curve, not its specific values. When you want to keep the shape of the lr curve and only tune the peak lr, you only need to change one place. This follows the principle that each module does only its own work, and their settings do not affect each other.

In my experience tuning models, this approach is less likely to cause mistakes where I change the optimizer lr but forget to change the scheduler.

Also, in some papers, if I recall correctly, the authors state that they use a cosine lr schedule to train their model, with the learning rate annealing from max_lr to 0.1 * max_lr. I think many other people accept this way of specifying learning rates.

@tjruwase
Contributor

tjruwase commented Nov 17, 2023

@CoinCheung, thanks for your response. I agree with the differences you identify between WarmupLR and WarmupCosineLR, but to me these differences are simply in the implementation and logic. At a high level they are similar because they both provide two phases of lr changes: (1) an initial phase of warmup/increase, and (2) a final phase of no change or decay. Looking more closely, we see significant similarity or duplication in many of the methods, including step, state_dict, load_state_dict, get_last_lr, and _format_param. These similarities suggest to me opportunities for code refactoring and reuse.
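A minimal sketch of what such a shared base might look like, using the method names from the comment above; this is purely illustrative and not DeepSpeed's actual code (the _get_lrs hook name is an assumption):

```python
class _TwoPhaseLRBase:
    """Shared bookkeeping; subclasses only decide how lrs are computed per step."""

    def __init__(self, optimizer, last_batch_iteration=-1):
        self.optimizer = optimizer
        self.last_batch_iteration = last_batch_iteration
        self._last_lr = [group["lr"] for group in optimizer.param_groups]

    def _get_lrs(self, iteration):
        # WarmupLR would return absolute values; WarmupCosineLR would return
        # ratio * the optimizer's original lrs (one entry per param group).
        raise NotImplementedError

    def step(self, last_batch_iteration=None):
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration
        self._last_lr = self._get_lrs(last_batch_iteration)
        for group, lr in zip(self.optimizer.param_groups, self._last_lr):
            group["lr"] = lr

    def get_last_lr(self):
        return self._last_lr

    def state_dict(self):
        return {"last_batch_iteration": self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd["last_batch_iteration"]
```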

@CoinCheung
Contributor Author

CoinCheung commented Nov 17, 2023

@tjruwase Would it be acceptable to change the args (the init args used to define the scheduler object) of WarmupLR? It has only one subclass, WarmupDecayLR, and I think its usage frequency is not very high.

mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
add cosine annealing scheduler

this scheduler is widely used in image classification task, and many llm
(e.g. llama) use this also.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>