- November 2024: torchtune has released v0.4.0, which includes stable support for exciting features like activation offloading and multimodal QLoRA
- November 2024: torchtune has added Gemma2 to its models!
- October 2024: torchtune added support for Qwen2.5 models - find the recipes here
- September 2024: torchtune has support for Llama 3.2 11B Vision, Llama 3.2 3B, and Llama 3.2 1B models! Try them out by following our installation instructions here, then run any of the text configs here or vision configs here.
torchtune is a PyTorch library for easily authoring, finetuning and experimenting with LLMs.
torchtune provides:
- PyTorch implementations of popular LLMs from Llama, Gemma, Mistral, Phi, and Qwen model families
- Hackable training recipes for full finetuning, LoRA, QLoRA, DPO, PPO, QAT, knowledge distillation, and more
- Out-of-the-box memory efficiency, performance improvements, and scaling with the latest PyTorch APIs
- YAML configs for easily configuring training, evaluation, quantization or inference recipes
- Built-in support for many popular dataset formats and prompt templates
torchtune currently supports the following models.
Model | Sizes |
---|---|
Llama3.2-Vision | 11B, 90B [models, configs] |
Llama3.2 | 1B, 3B [models, configs] |
Llama3.1 | 8B, 70B, 405B [models, configs] |
Llama3 | 8B, 70B [models, configs] |
Llama2 | 7B, 13B, 70B [models, configs] |
Code-Llama2 | 7B, 13B, 70B [models, configs] |
Mistral | 7B [models, configs] |
Gemma | 2B, 7B [models, configs] |
Gemma2 | 2B, 9B, 27B [models, configs] |
Microsoft Phi3 | Mini [models, configs] |
Qwen2 | 0.5B, 1.5B, 7B [models, configs] |
Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B [models, configs] |
We're always adding new models, but feel free to file an issue if there's a new one you would like to see in torchtune.
torchtune provides the following finetuning recipes for training on one or more devices.
Finetuning Method | Devices | Recipe | Example Config(s) |
---|---|---|---|
Full Finetuning | 1-8 | full_finetune_single_device, full_finetune_distributed | Llama3.1 8B single-device, Llama 3.1 70B distributed |
LoRA Finetuning | 1-8 | lora_finetune_single_device, lora_finetune_distributed | Qwen2 0.5B single-device, Gemma 7B distributed |
QLoRA Finetuning | 1-8 | lora_finetune_single_device, lora_finetune_distributed | Phi3 Mini single-device, Llama 3.1 405B distributed |
DoRA/QDoRA Finetuning | 1-8 | lora_finetune_single_device, lora_finetune_distributed | Llama3 8B QDoRA single-device, Llama3 8B DoRA distributed |
Quantization-Aware Training | 2-8 | qat_distributed | Llama3 8B QAT |
Quantization-Aware Training and LoRA Finetuning | 2-8 | qat_lora_finetune_distributed | Llama3 8B QAT |
Direct Preference Optimization | 1-8 | lora_dpo_single_device, lora_dpo_distributed | Llama2 7B single-device, Llama2 7B distributed |
Proximal Policy Optimization | 1 | ppo_full_finetune_single_device | Mistral 7B |
Knowledge Distillation | 1 | knowledge_distillation_single_device | Qwen2 1.5B -> 0.5B |
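The recipe names above follow a single_device / distributed naming pattern, and several methods share the same recipe. As a rough sketch of that mapping (the pick_recipe helper and its method keywords are hypothetical and exist only for illustration), a launch script could choose the recipe from the method and the number of devices:

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the tune CLI): map a finetuning method and a
# device count to one of the recipe names listed in the table above.
pick_recipe() {
  local method="$1" ngpus="$2" prefix
  case "$method" in
    full) prefix="full_finetune" ;;
    lora|qlora|dora|qdora) prefix="lora_finetune" ;;  # these methods share the LoRA recipes
    dpo) prefix="lora_dpo" ;;
    qat)
      # QAT is only listed for 2-8 devices and has no single-device recipe
      if [ "$ngpus" -lt 2 ]; then
        echo "qat_distributed requires at least 2 devices" >&2
        return 1
      fi
      echo "qat_distributed"
      return 0
      ;;
    *)
      echo "unknown method: $method" >&2
      return 1
      ;;
  esac
  # Methods with both variants: one device -> single_device, more -> distributed
  if [ "$ngpus" -gt 1 ]; then
    echo "${prefix}_distributed"
  else
    echo "${prefix}_single_device"
  fi
}

pick_recipe lora 1   # prints lora_finetune_single_device
pick_recipe full 8   # prints full_finetune_distributed
```

Recipes that the table lists for a single device only (PPO, knowledge distillation) are left out of this sketch.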
The configs above are only examples to get you started. If a model from the supported-models table is not listed here, it is likely still supported. If you're unsure whether something is supported, please open an issue on the repo.
Below is an example of the memory requirements and training speed for different Llama 3.1 models.
Note
For ease of comparison, all numbers below use batch size 2 (without gradient accumulation), a dataset packed to sequence length 2048, and torch compile enabled.
If you are interested in running on different hardware or with different models, check out our documentation on memory optimizations here to find the right setup for you.
Model | Finetuning Method | Runnable On | Peak Memory per GPU | Tokens/sec * |
---|---|---|---|---|
Llama 3.1 8B | Full finetune | 1x 4090 | 18.9 GiB | 1650 |
Llama 3.1 8B | Full finetune | 1x A6000 | 37.4 GiB | 2579 |
Llama 3.1 8B | LoRA | 1x 4090 | 16.2 GiB | 3083 |
Llama 3.1 8B | LoRA | 1x A6000 | 30.3 GiB | 4699 |
Llama 3.1 8B | QLoRA | 1x 4090 | 7.4 GiB | 2413 |
Llama 3.1 70B | Full finetune | 8x A100 | 13.9 GiB ** | 1568 |
Llama 3.1 70B | LoRA | 8x A100 | 27.6 GiB | 3497 |
Llama 3.1 405B | QLoRA | 8x A100 | 44.8 GB | 653 |
*= Measured over one full training epoch
**= Uses CPU offload with fused optimizer
torchtune exposes a number of levers for memory efficiency and performance. The table below demonstrates the effects of applying some of these techniques sequentially to the Llama 3.2 3B model. Each technique is added on top of the previous one, except for LoRA and QLoRA, which do not use optimizer_in_bwd or the AdamW8bit optimizer.
Baseline:
- Model: Llama 3.2 3B
- Batch size: 2
- Max seq len: 4096
- Precision: bf16
- Hardware: A100
- Recipe: full_finetune_single_device
Technique | Peak Memory Active (GiB) | % Change Memory vs Previous | Tokens Per Second | % Change Tokens/sec vs Previous |
---|---|---|---|---|
Baseline | 25.5 | - | 2091 | - |
+ Packed Dataset | 60.0 | +135.16% | 7075 | +238.40% |
+ Compile | 51.0 | -14.93% | 8998 | +27.18% |
+ Chunked Cross Entropy | 42.9 | -15.83% | 9174 | +1.96% |
+ Activation Checkpointing | 24.9 | -41.93% | 7210 | -21.41% |
+ Fuse optimizer step into backward | 23.1 | -7.29% | 7309 | +1.38% |
+ Activation Offloading | 21.8 | -5.48% | 7301 | -0.11% |
+ 8-bit AdamW | 17.6 | -19.63% | 6960 | -4.67% |
LoRA | 8.5 | -51.61% | 8210 | +17.96% |
QLoRA | 4.6 | -45.71% | 8035 | -2.13% |
Compared with the baseline, the final row in the table (QLoRA) uses 81.9% less memory and delivers a 284.3% increase in tokens per second. It can be run via the command:
tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device \
dataset.packed=True \
compile=True \
loss=torchtune.modules.loss.CEWithChunkedOutputLoss \
enable_activation_checkpointing=True \
optimizer_in_bwd=False \
enable_activation_offloading=True \
optimizer._component_=torch.optim.AdamW \
tokenizer.max_seq_len=4096 \
gradient_accumulation_steps=1 \
epochs=1 \
batch_size=2
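The quoted percentages can be checked against the table. As a small sketch (pct_change is a hypothetical helper), computing the change from the Baseline row (25.5 GiB, 2091 tokens/sec) to the final QLoRA row (4.6 GiB, 8035 tokens/sec) with awk:

```bash
#!/usr/bin/env bash
# Hypothetical helper: percentage change from $1 (before) to $2 (after),
# printed with two decimals. Negative values mean a reduction.
pct_change() {
  awk -v before="$1" -v after="$2" 'BEGIN { printf "%.2f\n", (after - before) / before * 100 }'
}

pct_change 25.5 4.6    # peak memory, Baseline -> QLoRA: prints -81.96
pct_change 2091 8035   # tokens/sec,  Baseline -> QLoRA: prints 284.27
```

Working from the rounded table values gives about 81.96% less memory and 284.27% more tokens per second, which matches the quoted figures to within rounding.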
torchtune is tested with the latest stable PyTorch release as well as the preview nightly version. torchtune leverages torchvision for finetuning multimodal LLMs and torchao for the latest in quantization techniques; you should install these as well.
# Install stable releases of PyTorch, torchvision and torchao
pip install torch torchvision torchao
pip install torchtune
# Install PyTorch, torchvision, torchao nightlies
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu
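The nightly index URL differs only in its last path component. A minimal sketch (nightly_index_url is a hypothetical helper) that accepts only the variants named in the comment above, so a typo fails immediately instead of surfacing as a confusing pip error. Newer nightlies may offer other variants, so treat that list as an assumption:

```bash
#!/usr/bin/env bash
# Hypothetical helper: build the PyTorch nightly index URL for a compute variant.
# The accepted variants are the ones listed in the install comment above.
nightly_index_url() {
  case "$1" in
    cpu|cu118|cu121|cu124)
      echo "https://download.pytorch.org/whl/nightly/$1"
      ;;
    *)
      echo "unsupported variant: $1 (expected cpu, cu118, cu121 or cu124)" >&2
      return 1
      ;;
  esac
}

# Usage: print the install command instead of running it
variant="cu124"
url="$(nightly_index_url "$variant")" &&
  echo "pip install --pre --upgrade torch torchvision torchao --index-url $url"
```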
You can also check out our install documentation for more information, including installing torchtune from source.
To confirm that the package is installed correctly, you can run the following command:
tune --help
You should see the following output:
usage: tune [-h] {ls,cp,download,run,validate} ...
Welcome to the torchtune CLI!
options:
-h, --help show this help message and exit
...
To get started with torchtune, see our First Finetune Tutorial. Our End-to-End Workflow Tutorial will show you how to evaluate, quantize and run inference with a Llama model. The rest of this section will provide a quick overview of these steps with Llama3.1.
Follow the instructions on the official meta-llama repository to ensure you have access to the official Llama model weights. Once you have confirmed access, you can run the following command to download the weights to your local machine. This will also download the tokenizer model and a responsible use guide.
To download Llama3.1, you can run:
tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
--output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
--ignore-patterns "original/consolidated.00.pth" \
--hf-token <HF_TOKEN>
Tip
Set your environment variable HF_TOKEN or pass in --hf-token to the command in order to validate your access. You can find your token at https://huggingface.co/settings/tokens
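Following the tip, one way to keep the token out of the command itself is to export HF_TOKEN and omit --hf-token. The sketch below adds a hypothetical pre-flight check (require_hf_token is not part of torchtune) so a missing token is reported before any download starts; the actual tune download call is left commented out:

```bash
#!/usr/bin/env bash
# Hypothetical pre-flight check: fail early if no Hugging Face token is set.
require_hf_token() {
  if [ -z "${HF_TOKEN:-}" ]; then
    echo "HF_TOKEN is not set; create a token at https://huggingface.co/settings/tokens" >&2
    return 1
  fi
}

# Per the tip above, the token can come from HF_TOKEN instead of --hf-token.
if require_hf_token; then
  echo "token found; ready to download"
  # tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
  #   --output-dir /tmp/Meta-Llama-3.1-8B-Instruct \
  #   --ignore-patterns "original/consolidated.00.pth"
fi
```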
You can finetune Llama3.1 8B with LoRA on a single GPU using the following command:
tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
For distributed training, tune CLI integrates with torchrun. To run a full finetune of Llama3.1 8B on two GPUs:
tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full
Tip
Make sure to place any torchrun options (such as --nproc_per_node) before the recipe name. Any CLI arguments after the recipe are treated as config overrides and do not affect the distributed launch.
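To make the ordering rule concrete, here is a minimal dry-run sketch (build_tune_cmd is a hypothetical helper) that places the launcher option --nproc_per_node before the recipe and appends config overrides after --config:

```bash
#!/usr/bin/env bash
# Hypothetical dry-run helper: print a tune run command with the torchrun
# option before the recipe and config overrides after --config.
# Usage: build_tune_cmd NPROC RECIPE CONFIG [key=value ...]
build_tune_cmd() {
  local nproc="$1" recipe="$2" config="$3" override
  shift 3
  local cmd="tune run"
  # Launcher (torchrun) options must precede the recipe name
  if [ "$nproc" -gt 1 ]; then
    cmd="$cmd --nproc_per_node $nproc"
  fi
  cmd="$cmd $recipe --config $config"
  # Arguments after the recipe are treated as config overrides
  for override in "$@"; do
    cmd="$cmd $override"
  done
  echo "$cmd"
}

build_tune_cmd 2 full_finetune_distributed llama3_1/8B_full batch_size=4
# prints: tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full batch_size=4
```

Printing instead of executing keeps the sketch side-effect free. Joining the arguments into one string is fine for display but loses quoting; to actually run the command, collect the arguments in a bash array and execute that.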
There are two ways in which you can modify configs:
Config Overrides
You can directly override config fields from the command line:
tune run lora_finetune_single_device \
--config llama2/7B_lora_single_device \
batch_size=8 \
enable_activation_checkpointing=True \
max_steps_per_epoch=128
Update a Local Copy
You can also copy the config to your local directory and modify the contents directly:
tune cp llama3_1/8B_full ./my_custom_config.yaml
Copied to ./my_custom_config.yaml
Then, you can run with your custom config by pointing the tune run command to your local file:
tune run full_finetune_distributed --config ./my_custom_config.yaml
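When only a top-level scalar in the copied file needs to change, a one-line sed edit can be enough. A minimal sketch, assuming the copy contains an unindented batch_size: line (as the batch_size=8 override above suggests); set_top_level_key is a hypothetical helper and does not understand nested YAML:

```bash
#!/usr/bin/env bash
# Hypothetical helper: replace the value of a top-level (unindented) YAML key
# in place. Nested keys such as dataset.packed are not handled, and the value
# must not contain the '|' character used as the sed delimiter.
set_top_level_key() {
  local file="$1" key="$2" value="$3"
  if ! grep -q "^${key}:" "$file"; then
    echo "key '$key' not found at top level of $file" >&2
    return 1
  fi
  # Write to a temp file and move it back to avoid GNU/BSD sed -i differences
  sed "s|^${key}:.*|${key}: ${value}|" "$file" > "$file.tmp" && mv "$file.tmp" "$file"
}

# Usage after: tune cp llama3_1/8B_full ./my_custom_config.yaml
# set_top_level_key ./my_custom_config.yaml batch_size 4
```

For one-off experiments, command-line overrides are usually simpler; editing the copy is mainly useful when the change should persist across runs.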
Check out tune --help for all possible CLI commands and options. For more information on using and updating configs, take a look at our config deep-dive.
torchtune supports finetuning on a variety of different datasets, including instruct-style, chat-style, preference datasets, and more. If you want to learn more about how to apply these components to finetune on your own custom dataset, please check out the provided links along with our API docs.
torchtune focuses on integrating with popular tools and libraries from the ecosystem. These are just a few examples, with more under development:
- Hugging Face Hub for accessing model weights
- EleutherAI's LM Eval Harness for evaluating trained models
- Hugging Face Datasets for access to training and evaluation datasets
- PyTorch FSDP2 for distributed training
- torchao for lower precision dtypes and post-training quantization techniques
- Weights & Biases for logging metrics and checkpoints, and tracking training progress
- Comet as another option for logging
- ExecuTorch for on-device inference using finetuned models
- bitsandbytes for low memory optimizers for our single-device recipes
- PEFT for continued finetuning or inference with torchtune models in the Hugging Face ecosystem
We really value our community and the contributions made by our wonderful users. We'll use this section to call out some of these contributions. If you'd like to help out as well, please see the CONTRIBUTING guide.
- @SalmanMohammadi for adding a comprehensive end-to-end recipe for Reinforcement Learning from Human Feedback (RLHF) finetuning with PPO to torchtune
- @fyabc for adding Qwen2 models, tokenizer, and recipe integration to torchtune
- @solitude-alive for adding the Gemma 2B model to torchtune, including recipe changes, numeric validations of the models and recipe correctness
- @yechenzhi for adding Direct Preference Optimization (DPO) to torchtune, including the recipe and config along with correctness checks
- @Optimox for adding all the Gemma2 variants to torchtune!
The Llama2 code in this repository is inspired by the original Llama2 code.
We want to give a huge shout-out to EleutherAI, Hugging Face and Weights & Biases for being wonderful collaborators and for working with us on some of these integrations within torchtune.
We also want to acknowledge some awesome libraries and tools from the ecosystem:
- gpt-fast for performant LLM inference techniques which we've adopted out-of-the-box
- llama recipes for spring-boarding the llama2 community
- bitsandbytes for bringing several memory and performance based techniques to the PyTorch ecosystem
- @winglian and axolotl for early feedback and brainstorming on torchtune's design and feature set.
- lit-gpt for pushing the LLM finetuning community forward.
- HF TRL for making reward modeling more accessible to the PyTorch community.
torchtune is released under the BSD 3-Clause license. However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models.
If you find the torchtune library useful, please cite it in your work as below.
@software{torchtune,
title = {torchtune: PyTorch's finetuning library},
author = {torchtune maintainers and contributors},
url = {https://github.com/pytorch/torchtune},
license = {BSD-3-Clause},
month = apr,
year = {2024}
}