Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add gemma2 variants #1835

Merged
merged 12 commits into from
Nov 8, 2024
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ torchtune currently supports the following models.
| [Code-Llama2](https://ai.meta.com/blog/code-llama-large-language-model-coding/) | 7B, 13B, 70B [[models](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] |
| [Mistral](https://huggingface.co/mistralai) | 7B [[models](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] |
| [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B, 7B [[models](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] |
| [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2) | 2B, 9B, 27B [[models](torchtune/models/gemma2/_model_builders.py), [configs](recipes/configs/gemma2/)] |
| [Microsoft Phi3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) | Mini [[models](torchtune/models/phi3/), [configs](recipes/configs/phi3/)]
| [Qwen2](https://qwenlm.github.io/blog/qwen2/) | 0.5B, 1.5B, 7B [[models](torchtune/models/qwen2/), [configs](recipes/configs/qwen2/)]

Expand Down
31 changes: 31 additions & 0 deletions docs/source/api_ref_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,37 @@ To download the Gemma 7B model:
gemma.gemma_tokenizer


gemma2 :
--------

Models of size 2B, 9B, 27B from the `Gemma family <https://blog.google/technology/developers/gemma-open-models/>`_.

Important: You need to request access on `Hugging Face <https://huggingface.co/google/gemma-2-2b>`__ to use this model.

To download the Gemma2 2B, 9B, 27B models :

.. code-block:: bash

tune download google/gemma-2-<MODEL_SIZE>b --ignore-patterns "gemma-2-<MODEL_SIZE>b.gguf" --hf-token <HF_TOKEN>


.. autosummary::
:toctree: generated/
:nosignatures:

gemma2.gemma2
gemma2.lora_gemma2
gemma2.gemma2_2b
gemma2.lora_gemma2_2b
gemma2.qlora_gemma2_2b
gemma2.gemma2_9b
gemma2.lora_gemma2_9b
gemma2.qlora_gemma2_9b
gemma2.gemma2_27b
gemma2.lora_gemma2_27b
gemma2.qlora_gemma2_27b
gemma.gemma_tokenizer

clip
-----

Expand Down
74 changes: 74 additions & 0 deletions recipes/configs/gemma2/27B_full.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did some quick math, I guess this will take at least 216GB total memory (54GB params + 54GB gradients + 108GB optimizer states for AdamW) , which means to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer + optimizer in backward to get us down to a more reasonable peak VRAM here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does 8bit work with distributed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah duh.. there may be some issues with bitsandbytes optimizers on that front. I just tried out ao low-precision optimizers and it seems to work (though haven't resumed from intermediate checkpoint). Also there may be a compile dep there. Anyways if it's too much hassle we can consider it separately, don't wanna increase the scope of this already substantial PR more than necessary

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What should I do here? Change something or expect users to change parameters according to their hardware ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR

#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to potentially be a pain in the ass here. We have parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have packed: False in dataset, and log_peak_memory_stats: True and compile: False below, for every one of our configs.

Would it be annoying to ask if we could update these in the same way while we're here, please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done I have updated all the configs to match the other PR!

packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.gemma2.gemma_27b
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_component_: torchtune.models.gemma2.gemma_27b
_component_: torchtune.models.gemma2.gemma2_27b


checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2-27b/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00024
recipe_checkpoint: null
output_dir: /tmp/gemma-2-27b
model_type: GEMMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 2e-5
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False # pytorch compile, set to true for perf/memory improvement

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-finetune
log_every_n_steps: 1
log_peak_memory_stats: True
86 changes: 86 additions & 0 deletions recipes/configs/gemma2/27B_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when the model is being fine-tuned on 2+ GPUs.


# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_27b
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 64
lora_alpha: 128
lora_dropout: 0.0

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2-27b/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00024
recipe_checkpoint: null
output_dir: /tmp/gemma-2-27b/
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 2e-5

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 4
epochs: 3
max_steps_per_epoch: null
gradient_accumulation_steps: 1
compile: False # pytorch compile, set to true for perf/memory improvement

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-lora
log_every_n_steps: 1
log_peak_memory_stats: True
112 changes: 112 additions & 0 deletions recipes/configs/gemma2/27B_lora_single_device.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Config for multi-device LoRA finetuning in lora_finetune_single_device.py
# using a gemma2 27B model
#
# This config assumes that you've run the following command before launching
# this run (torchtune does not use gguf so you can ignore it to save time and space):
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_27b
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
apply_lora_to_mlp: True
lora_rank: 8
lora_alpha: 16
lora_dropout: 0.0

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2-27b/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00024
recipe_checkpoint: null
output_dir: /tmp/gemma-2-27b/
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 5e-5

lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 10

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 2
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8
compile: False # pytorch compile, set to true for perf/memory improvement

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/alpaca-gemma2-27b-lora
log_every_n_steps: 1
log_peak_memory_stats: True

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 5
active_steps: 2
num_cycles: 1
Loading
Loading