diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst
index a95d9306f0..2c0b0fc735 100644
--- a/docs/source/api_ref_modules.rst
+++ b/docs/source/api_ref_modules.rst
@@ -19,10 +19,10 @@ Modeling Components and Building Blocks
     RMSNorm
     Fp32LayerNorm
     TanhGate
+    TiedLinear
     TransformerSelfAttentionLayer
     TransformerCrossAttentionLayer
     TransformerDecoder
-    TiedEmbeddingTransformerDecoder
     VisionTransformer

 Base Tokenizers
diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml
index 85e9c6577c..35516eefff 100644
--- a/recipes/configs/qwen2/1.5B_lora.yaml
+++ b/recipes/configs/qwen2/1.5B_lora.yaml
@@ -1,5 +1,5 @@
 # Config for multi-device LoRA finetuning in lora_finetune_distributed.py
-# using a Qwen2 0.5B model
+# using a Qwen2 1.5B model
 #
 # This config assumes that you've run the following command before launching
 # this run:
@@ -27,18 +27,18 @@ model:

 tokenizer:
   _component_: torchtune.models.qwen2.qwen2_tokenizer
-  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
-  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
+  path: /tmp/Qwen2-1.5B-Instruct/vocab.json
+  merges_file: /tmp/Qwen2-1.5B-Instruct/merges.txt
   max_seq_len: null

 checkpointer:
   _component_: torchtune.training.FullModelHFCheckpointer
-  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
+  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct
   checkpoint_files: [
     model.safetensors
   ]
   recipe_checkpoint: null
-  output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
+  output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
   model_type: QWEN2

 resume_from_checkpoint: False
@@ -67,7 +67,7 @@ max_steps_per_epoch: null
 gradient_accumulation_steps: 8

 # Logging
-output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
+output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
diff --git a/torchtune/models/qwen2/_component_builders.py b/torchtune/models/qwen2/_component_builders.py
index 734aae1423..716fe337ad 100644
--- a/torchtune/models/qwen2/_component_builders.py
+++ b/torchtune/models/qwen2/_component_builders.py
@@ -5,12 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 from functools import partial
-from typing import List, Union
+from typing import List

 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

 from torch import nn
-
-from torchtune.modules.transformer import TransformerDecoder, TiedEmbeddingTransformerDecoder
+from torchtune.modules.transformer import TransformerDecoder

 from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings

 from torchtune.modules import (
@@ -18,6 +17,7 @@
     FeedForward,
     RMSNorm,
     TransformerSelfAttentionLayer,
+    TiedLinear
 )

@@ -48,7 +48,7 @@ def qwen2(
     norm_eps: float = 1e-5,
     rope_base: float = 1_000_000.0,
     tie_word_embeddings: bool = False,
-) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]:
+) -> TransformerDecoder:
     """
     Build the decoder associated with the Qwen2 model. This includes:
     - Token embeddings
@@ -104,28 +104,20 @@ def qwen2(
         mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
     )
     tok_embeddings = nn.Embedding(vocab_size, embed_dim)
-    output_proj = None if tie_word_embeddings else nn.Linear(embed_dim, vocab_size, bias=False)
-    if output_proj is None:
-        return TiedEmbeddingTransformerDecoder(
-            tok_embeddings=tok_embeddings,
-            layers=layer,
-            num_layers=num_layers,
-            max_seq_len=max_seq_len,
-            num_heads=num_heads,
-            head_dim=head_dim,
-            norm=RMSNorm(embed_dim, eps=norm_eps),
-        )
+    if tie_word_embeddings:
+        output_proj = TiedLinear(tok_embeddings)
     else:
-        return TransformerDecoder(
-            tok_embeddings=tok_embeddings,
-            layers=layer,
-            num_layers=num_layers,
-            max_seq_len=max_seq_len,
-            num_heads=num_heads,
-            head_dim=head_dim,
-            norm=RMSNorm(embed_dim, eps=norm_eps),
-            output=output_proj,
-        )
+        output_proj = nn.Linear(embed_dim, vocab_size, bias=False)
+    return TransformerDecoder(
+        tok_embeddings=tok_embeddings,
+        layers=layer,
+        num_layers=num_layers,
+        max_seq_len=max_seq_len,
+        num_heads=num_heads,
+        head_dim=head_dim,
+        norm=RMSNorm(embed_dim, eps=norm_eps),
+        output=output_proj,
+    )


 def qwen2_mlp(dim: int, hidden_dim: int) -> FeedForward:
@@ -162,7 +154,7 @@ def lora_qwen2(
     use_dora: bool = False,
     # Quantization args
     quantize_base: bool = False,
-) -> Union[TransformerDecoder, TiedEmbeddingTransformerDecoder]:
+) -> TransformerDecoder:
     """
     Return a version of Qwen2 (an instance of :func:`~torchtune.models.qwen2.transformer.Qwen2TransformerDecoder`)
     with LoRA applied based on the passed in configuration.
@@ -251,7 +243,7 @@ def lora_qwen2(
                 "apply_lora_to_output is incompatible with tie_word_embeddings,"
                 " as there would be no output to apply lora to!"
             )
-        output_proj = None
+        output_proj = TiedLinear(tok_embeddings)
     else:
         # TODO: quantize_base is not applied to final output_proj currently.
         adapter_cls = DoRALinear if use_dora else LoRALinear
@@ -260,27 +252,16 @@ def lora_qwen2(
             if apply_lora_to_output
             else nn.Linear(embed_dim, vocab_size, bias=False)
         )
-    if output_proj is None:
-        model = TiedEmbeddingTransformerDecoder(
-            tok_embeddings=tok_embeddings,
-            layers=layer,
-            num_layers=num_layers,
-            max_seq_len=max_seq_len,
-            num_heads=num_heads,
-            head_dim=(embed_dim // num_heads),
-            norm=RMSNorm(embed_dim, eps=norm_eps),
-        )
-    else:
-        model = TransformerDecoder(
-            tok_embeddings=tok_embeddings,
-            layers=layer,
-            num_layers=num_layers,
-            max_seq_len=max_seq_len,
-            num_heads=num_heads,
-            head_dim=(embed_dim // num_heads),
-            norm=RMSNorm(embed_dim, eps=norm_eps),
-            output=output_proj,
-        )
+    model = TransformerDecoder(
+        tok_embeddings=tok_embeddings,
+        layers=layer,
+        num_layers=num_layers,
+        max_seq_len=max_seq_len,
+        num_heads=num_heads,
+        head_dim=(embed_dim // num_heads),
+        norm=RMSNorm(embed_dim, eps=norm_eps),
+        output=output_proj,
+    )

     if quantize_base:
         # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly
diff --git a/torchtune/models/qwen2/_model_builders.py b/torchtune/models/qwen2/_model_builders.py
index 1ec9cc53e5..e8e7334d57 100644
--- a/torchtune/models/qwen2/_model_builders.py
+++ b/torchtune/models/qwen2/_model_builders.py
@@ -7,7 +7,7 @@
 from torchtune.models.qwen2._component_builders import qwen2, lora_qwen2
 from torchtune.models.qwen2._tokenizer import Qwen2Tokenizer

-from torchtune.modules import TransformerDecoder, TiedEmbeddingTransformerDecoder
+from torchtune.modules import TransformerDecoder
 from torchtune.modules.peft import LORA_ATTN_MODULES
 from torchtune.modules.tokenizers import parse_hf_tokenizer_json
 from torchtune.data._prompt_templates import _TemplateType
@@ -42,17 +42,17 @@ def qwen2_7b() -> TransformerDecoder:
     )


-def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder:
+def qwen2_0_5b() -> TransformerDecoder:
     """
     Builder for creating a Qwen2 model initialized w/ the default 0.5B parameter values
     from https://huggingface.co/Qwen/Qwen2-0.5B-Instruct

     Returns:
-        TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model
+        TransformerDecoder: Instantiation of Qwen2 0.5B model

     Note:
         Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
-        and returns an instance of `TiedEmbeddingTransformerDecoder`.
+        and return an instance of `TransformerDecoder`.
     """
     return qwen2(
         vocab_size=151936,
@@ -69,17 +69,17 @@ def qwen2_0_5b() -> TiedEmbeddingTransformerDecoder:
     )


-def qwen2_1_5b() -> TiedEmbeddingTransformerDecoder:
+def qwen2_1_5b() -> TransformerDecoder:
     """
     Builder for creating a Qwen2 model initialized w/ the default 1.5B parameter values
     from https://huggingface.co/Qwen/Qwen2-1.5B-Instruct

     Returns:
-        TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model
+        TransformerDecoder: Instantiation of Qwen2 1.5B model

     Note:
         Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
-        and returns an instance of `TiedEmbeddingTransformerDecoder`.
+        and return an instance of `TransformerDecoder`.
     """
     return qwen2(
         vocab_size=151936,
@@ -191,7 +191,7 @@ def lora_qwen2_0_5b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
-) -> TiedEmbeddingTransformerDecoder:
+) -> TransformerDecoder:
     """
     Builder for creating a Qwen2 0.5B model with LoRA enabled.

@@ -211,11 +211,11 @@ def lora_qwen2_0_5b(
         quantize_base (bool): Whether to quantize base model weights

     Returns:
-        TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied
+        TransformerDecoder: Instantiation of Qwen2 0.5B model with LoRA applied

     Note:
         Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
-        and returns an instance of `TiedEmbeddingTransformerDecoder`.
+        and return an instance of `TransformerDecoder`.
     """
     return lora_qwen2(
         lora_attn_modules=lora_attn_modules,
@@ -248,7 +248,7 @@ def lora_qwen2_1_5b(
     lora_dropout: float = 0.0,
     use_dora: bool = False,
     quantize_base: bool = False,
-) -> TiedEmbeddingTransformerDecoder:
+) -> TransformerDecoder:
     """
     Builder for creating a Qwen2 1.5B model with LoRA enabled.

@@ -268,11 +268,11 @@ def lora_qwen2_1_5b(
         quantize_base (bool): Whether to quantize base model weights

     Returns:
-        TiedEmbeddingTransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied
+        TransformerDecoder: Instantiation of Qwen2 1.5B model with LoRA applied

     Note:
         Qwen2 0.5B and Qwen2 1.5B model builders will enable `tie_word_embeddings` by default
-        and returns an instance of `TiedEmbeddingTransformerDecoder`.
+        and return an instance of `TransformerDecoder`.
     """
     return lora_qwen2(
         lora_attn_modules=lora_attn_modules,
diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py
index 695a7c6f1f..66076d52f9 100644
--- a/torchtune/modules/__init__.py
+++ b/torchtune/modules/__init__.py
@@ -15,6 +15,7 @@
 from .position_embeddings import RotaryPositionalEmbeddings  # noqa
 from .rms_norm import RMSNorm  # noqa
 from .tanh_gate import TanhGate  # noqa
+from .tied_linear import TiedLinear  # noqa
 from .transformer import (  # noqa
     TiedEmbeddingTransformerDecoder,
     TransformerCrossAttentionLayer,
@@ -32,6 +33,7 @@
     "KVCache",
     "RotaryPositionalEmbeddings",
     "RMSNorm",
+    "TiedLinear",
     "Fp32LayerNorm",
     "VisionTransformer",
     "TransformerDecoder",
diff --git a/torchtune/modules/tied_linear.py b/torchtune/modules/tied_linear.py
new file mode 100644
index 0000000000..6864f6fa1b
--- /dev/null
+++ b/torchtune/modules/tied_linear.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TiedLinear:
+    """
+    A tied linear layer, without bias, that shares the same weight as another linear layer.
+    This is useful for models that use tied weights, such as Qwen and Gemma.
+    It takes an nn.Module as input, rather than the module's weight tensor, so that it
+    works with FSDP; otherwise, the reference to the weight would be lost after FSDP is applied.
+
+    Args:
+        tied_module (nn.Module): The module whose weight is shared. Only
+            the weight is used. The bias is ignored.
+    Raises:
+        AttributeError: If the provided module does not have an attribute 'weight'.
+    """
+
+    def __init__(self, tied_module: nn.Module):
+        self.tied_module = tied_module
+        if not hasattr(tied_module, "weight"):
+            raise AttributeError(
+                "Provided module does not have attribute 'weight'. Please check your tied_module."
+            )
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.tied_module.weight)
diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 65e511e5ea..714924c69d 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -4,14 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union

 import torch
 import torch.nn.functional as F
 from torch import nn
 from torchtune.modules import MultiHeadAttention
-
 from torchtune.modules.attention_utils import _MaskType
+from torchtune.utils.logging import deprecated


 class TransformerSelfAttentionLayer(nn.Module):
@@ -295,7 +295,7 @@ class TransformerDecoder(nn.Module):
             to setup the :func:`~torchtune.modules.KVCache`
         norm (nn.Module): Callable that applies normalization to the output of the decoder, before
             final MLP.
-        output (nn.Linear): Callable that applies a linear transformation to the output of
+        output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of
             the decoder.
         num_layers (Optional[int]): Number of Transformer Decoder layers, only
             define when layers is not a list.
@@ -320,7 +320,7 @@ def __init__(
         num_heads: int,
         head_dim: int,
         norm: nn.Module,
-        output: nn.Linear,
+        output: Union[nn.Linear, Callable],
         num_layers: Optional[int] = None,
         output_hidden_states: Optional[List[int]] = None,
     ) -> None:
@@ -516,6 +516,11 @@ def forward(
         return output


+@deprecated(
+    msg="Please use torchtune.modules.TransformerDecoder instead. \
+If you need an example, see torchtune.models.qwen2._component_builders.py \
+and how it uses torchtune.modules.TiedLinear for the output projection."
+)
 class TiedEmbeddingTransformerDecoder(nn.Module):
     """
     Transformer Decoder with tied embedding weight. A key difference between
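
A minimal usage sketch of the new TiedLinear output projection, assuming the patch above is applied; the sizes below are illustrative and are not Qwen2's actual configuration:

import torch
import torch.nn as nn

from torchtune.modules import TiedLinear

# Tie the output projection to the token embedding, as the qwen2 builder does above.
vocab_size, embed_dim = 128, 16  # illustrative sizes
tok_embeddings = nn.Embedding(vocab_size, embed_dim)
output_proj = TiedLinear(tok_embeddings)  # no bias; reuses tok_embeddings.weight

tokens = torch.randint(0, vocab_size, (2, 8))   # [batch, seq_len]
hidden = tok_embeddings(tokens)                 # [batch, seq_len, embed_dim]
logits = output_proj(hidden)                    # [batch, seq_len, vocab_size]
assert logits.shape == (2, 8, vocab_size)

# Both views share one parameter: TiedLinear holds the module itself (not a copy of
# the weight), which is what keeps the tie intact after FSDP wraps the embedding.
assert output_proj.tied_module.weight is tok_embeddings.weight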