Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Early Exit Loss and/or Layer Dropout #1076

Merged
merged 90 commits into from
Dec 6, 2024

Conversation

mostafaelhoushi
Copy link
Contributor

@mostafaelhoushi mostafaelhoushi commented Jun 9, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature

Implements early exit loss and/pr layer dropout to reproduce experiments in various papers like LayerSkip, LITE, LayerDrop, and Progressive Layer Dropping.

Usage

Download Llama2 7B model:

tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>

To reproduce experiments of various papers that use early exit loss and/or layer dropout:

tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp
tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5
tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2
tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.scale=exp

Changelog

  • Modified TransformerDecoder by refactoring the operations that happen after running all transformer layers that transform last hidden state to a probability vector across all tokens into a separate unembed() function. This was needed to re-use unembed() to obtain predictions at intermediate layers, and hence calculate losses at intermediate layers.
  • Added to torchtune/modules/common_utils.py a slice_str_to_array() function that can convert numpy-like indexing string into a boolean array.
  • Added torchtune/modules/early_exit_loss.py to implement early exit loss
  • Added torchtune/modules/layer_dropiut.py to implement layer dropout
  • Added separate recipe script and configuration file
  • Added unit test cases

Test plan

  1. Download Llama 2 7B:
tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf 
  1. Finetune on TOPv2 instruction set (a small dataset that we can quickly finetune with early exit loss). W
tune run --nnodes 1 --nproc_per_node 8 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml 

We will run different variants of this command in the Table below to show different speed ups.

  1. Convert trained checkpoint to HF naming:
mv /tmp/Llama-2-7b-hf/hf_model_0001_0.pt /tmp/Llama-2-7b-hf/pytorch_model-00001-of-00002.bin
mv /tmp/Llama-2-7b-hf/hf_model_0002_0.pt /tmp/Llama-2-7b-hf/pytorch_model-00002-of-00002.bin
  1. Install HF transformers nightly (to enable early exit speculative decoding as described here):
pip install git+https://github.com/huggingface/transformers.git
pip install accelerate
  1. Save this code snippet to a file named eval_hf.py

  2. Run eval_hf.py that benchmarks autoregressive decoding and early exit self-speculative decoding:

python eval_hf.py

You should get an output that looks like this:

[INST]Set alarm for 6am every day[\INST]Set alarm for 6 am every day [IN:CREATE_ALARMSet alarm for 6 am every day ]
Orig Time: 1.159
[INST]Set alarm for 6am every day[\INST]Set alarm for 6 am every day [IN:CREATE_ALARMSet alarm for 6 am every day ]
Layerskip Time: 0.500

Hence, LayerSkip's training recipe has enabled LayerSkip's (early exit self-speculative decoding) inference that leads to a speedup of 1.159/0.500 = 2.318x.

We ran different variants of training and obtained the results below. Please note that differences in generation time for different training configurations may be due to different number of tokens generated. Hence, to measure speedup, we need to look at the different generation times for the same training configuration:
Link to Results Spreadsheets

Checklist

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Future Work

In the future we can:

  • Implement early exit self-speculative decoding in torchtune/generation/_generation.py: _generation.py uses similar code to gpt-fast. We have already implemented early-exit self-speculative decoding on gpt-fast here so it should be straightforward to port to torchtune.
  • Enable users to add their own heads or adapter modules at early exits rather than using a shared head for all early exits (this can enable implementing a paper like Kangaroo).
  • Make early_exit_loss() function support generic loss preprocessing: different training scripts may perform different pre-processing of logits and labels. We need to refactor training scripts slightly to encapsulate such preprocessing into a function to invoke when calculating early exit loss.
  • Make early_exit_loss() function support generic loss functions: support other types of losses like KL divergence loss, or losses represented as lists (as is the case with model's num_output_chunks>0).

Copy link

pytorch-bot bot commented Jun 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1076

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 69f840c with merge base f3d8d3c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 9, 2024
Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR! I think this would be a great feature for us to support. Left a couple of questions on design, especially how we expose things in TransformerDecoder. Also, is the idea to eventually support the self-speculative decoding technique as well?

torchtune/modules/transformer.py Outdated Show resolved Hide resolved
torchtune/modules/transformer.py Outdated Show resolved Hide resolved
torchtune/modules/layer_dropout.py Outdated Show resolved Hide resolved
torchtune/modules/transformer.py Outdated Show resolved Hide resolved
recipes/full_finetune_distributed.py Outdated Show resolved Hide resolved
torchtune/modules/layer_dropout.py Outdated Show resolved Hide resolved
recipes/full_finetune_distributed.py Outdated Show resolved Hide resolved
torchtune/modules/layer_dropout.py Outdated Show resolved Hide resolved
@Eshcar
Copy link

Eshcar commented Jun 18, 2024

Hi @mostafaelhoushi @ebsmothers @AkshatSh
this feature is very interesting
I see 2 branches with the name layer skip in the repository:
SkipLayer
ak/layer_skip
Do they both aim the same goal of skipping layers or early-exit at some layer?
Can you please elaborate more about the motivation behind this feature and the design of the solution

Thanks,
Eshcar

@mostafaelhoushi
Copy link
Contributor Author

Also, is the idea to eventually support the self-speculative decoding technique as well?

I am not sure if you would like to support inference optimization techniques in torchtune, but we have implemented self-speculative decoding in gpt-fast here: pytorch-labs/gpt-fast@main...LayerSkip.

@mostafaelhoushi
Copy link
Contributor Author

Hi @mostafaelhoushi @ebsmothers @AkshatSh this feature is very interesting I see 2 branches with the name layer skip in the repository: SkipLayer ak/layer_skip Do they both aim the same goal of skipping layers or early-exit at some layer? Can you please elaborate more about the motivation behind this feature and the design of the solution

Thanks, Eshcar

I want to double check, I can't find those 2 branches in this torchtune repo. Are you referring to torchtune repo or another repo?

@Eshcar
Copy link

Eshcar commented Jun 20, 2024

Hi @mostafaelhoushi @ebsmothers @AkshatSh this feature is very interesting I see 2 branches with the name layer skip in the repository: SkipLayer ak/layer_skip Do they both aim the same goal of skipping layers or early-exit at some layer? Can you please elaborate more about the motivation behind this feature and the design of the solution
Thanks, Eshcar

I want to double check, I can't find those 2 branches in this torchtune repo. Are you referring to torchtune repo or another repo?

apologies, my mistake it is from a different repository: pytorch-labs
https://github.com/pytorch-labs/gpt-fast/tree/LayerSkip by @mostafaelhoushi
https://github.com/pytorch-labs/gpt-fast/tree/ak/layer_skip by @AkshatSh

is there a connection between the 2 branches and the current issue?
what is the motivation for skipping layers in the middle of a token generation?
If there is some paper about this I would be happy to read it

@mostafaelhoushi
Copy link
Contributor Author

apologies, my mistake it is from a different repository: pytorch-labs https://github.com/pytorch-labs/gpt-fast/tree/LayerSkip by @mostafaelhoushi https://github.com/pytorch-labs/gpt-fast/tree/ak/layer_skip by @AkshatSh

is there a connection between the 2 branches and the current issue? what is the motivation for skipping layers in the middle of a token generation? If there is some paper about this I would be happy to read it

https://github.com/pytorch-labs/gpt-fast/tree/LayerSkip is just a slightly cleaned up fork of https://github.com/pytorch-labs/gpt-fast/tree/ak/layer_skip
They implement the self-speculative decoding inference described in this paper: https://arxiv.org/abs/2404.16710

This PR we are on in torchtune is to implement the training recipe of the same paper.

Copy link
Contributor Author

@mostafaelhoushi mostafaelhoushi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ebsmothers for your detailed review. Just made a batch of commits and comments to address the feedback.

log = utils.get_logger("DEBUG")


class LossScaleType(str, Enum):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am personally fine to follow the Callable approach.
The enum approach was sort of the convention being followed in the internal codebase I was working on.
If the Callable approach is sort of the convention for torchtune codebase, I can switch to that. Let me look for some example in the torchtune codebase to follow.

log = utils.get_logger("DEBUG")


class LossScaleType(str, Enum):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to look for an example to follow but can't find. For example, how can layer_ids_to_loss_scales determine the scale type?

train_last_layer = cfg_early_exit_loss.get("include_last_layer", True)
verbose = cfg_early_exit_loss.get("verbose", False)

if cfg_early_exit_loss.curriculum:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 40b7987

train_last_layer = cfg_early_exit_loss.get("include_last_layer", True)
verbose = cfg_early_exit_loss.get("verbose", False)

if cfg_early_exit_loss.curriculum:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I wanted to use config.instantiate(..) function call used in other functions like _setup_data() but I didn't understand how it works.

tests/torchtune/modules/test_early_exit_loss.py Outdated Show resolved Hide resolved
torchtune/modules/early_exit_loss.py Outdated Show resolved Hide resolved
log = utils.get_logger("DEBUG")


class LossScaleType(str, Enum):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In test_layer_ids_to_loss_scales, I just made the PyTest fixture test all the possible values in the enum:
https://github.com/mostafaelhoushi/torchtune/blob/3567a2431dc601efd29220a1ed2826e440fd43be/tests/torchtune/modules/test_early_exit_loss.py#L78C1-L81C6

I think that's one advantage of using enums (to ensure that we test all possible values). In the other solution you described, will we be able to do something similar?

GRADUAL = "gradual"


# TODO: create a base curriculum class that can be used for other aspects, e.g., dropout, datasets, etc.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the base class is EarlyExitCurriculum. I am thinking that in the future we can have a base class (or Protocol) Curriculum that can be used for other aspects or modules (dropout, dataset mixture weights, sparsity, etc. ) to change throughout training.

torchtune/modules/layer_dropout.py Outdated Show resolved Hide resolved
torchtune/modules/layer_dropout.py Outdated Show resolved Hide resolved
Comment on lines +1053 to +1057
init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(This is more as a note to myself but) we will need to incorporate the same set of changes from #2108 to get CPU offload + gradient clipping to work together here

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few more comments (mostly just responses to open threads). But I think this is just about there, really excited to see it land!

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mostafaelhoushi for enabling early exit loss and layer dropout in torchtune! I'm really happy with how it turned out, this is one of my favorite features we've landed (especially given some of the more challenging design aspects you had to work through). Thanks for sticking through it, looking forward to seeing its applications in LayerSkip and elsewhere!

@ebsmothers ebsmothers merged commit f8563dd into pytorch:main Dec 6, 2024
17 checks passed
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 8, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 9, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 18, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <[email protected]>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <[email protected]>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: vancoyendall <[email protected]>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <[email protected]>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <[email protected]>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <[email protected]>

* guard ckpt imports (pytorch#2133)

Co-authored-by: Felipe Mello <[email protected]>

* [bug fix] add parents=True (pytorch#2136)

Co-authored-by: Felipe Mello <[email protected]>

* [bug fix] re-add model (pytorch#2135)

Co-authored-by: Felipe Mello <[email protected]>

* Update save sizes into GiB (pytorch#2143)

* [bug fix] remove config download when source is kaggle (pytorch#2144)

Co-authored-by: Felipe Mello <[email protected]>

* [fix] remove "with_suffix" (pytorch#2146)

Co-authored-by: Felipe Mello <[email protected]>

* DoRA fixes (pytorch#2139)



Co-authored-by: Mircea Mironenco <[email protected]>

* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)

* Small readme, config updates (pytorch#2157)

* Using `FormattedCheckpointFiles` in configs (pytorch#2147)

* Move ``get_world_size_and_rank`` to utils (pytorch#2155)

* Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006)

Co-authored-by: Saurabh Mishra <[email protected]>

* torchdata integration - multi-dataset and streaming support (pytorch#1929)

* Allow higher version of lm-eval (pytorch#2165)

* Using `FormattedCheckpointFiles` in configs... round 2 (pytorch#2167)

* [EZ] Fix set_torch_num_threads in multi-node. (pytorch#2164)

---------

Co-authored-by: Philip Bontrager <[email protected]>
Co-authored-by: ebsmothers <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Felipe Mello <[email protected]>
Co-authored-by: Joe Cummings <[email protected]>
Co-authored-by: Mostafa Elhoushi <[email protected]>
Co-authored-by: vancoyendall <[email protected]>
Co-authored-by: Mircea Mironenco <[email protected]>
Co-authored-by: salman <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Saurabh Mishra <[email protected]>
Co-authored-by: Andrew Ho <[email protected]>
Co-authored-by: Eugen Hotaj <[email protected]>
rahul-sarvam pushed a commit to sarvamai/torchtune that referenced this pull request Dec 23, 2024
rahul-sarvam pushed a commit to sarvamai/torchtune that referenced this pull request Dec 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants