diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py
index 7e5f0cf5d9..66076d52f9 100644
--- a/torchtune/modules/__init__.py
+++ b/torchtune/modules/__init__.py
@@ -33,7 +33,8 @@
     "KVCache",
     "RotaryPositionalEmbeddings",
     "RMSNorm",
-    "TiedLinear" "Fp32LayerNorm",
+    "TiedLinear",
+    "Fp32LayerNorm",
     "VisionTransformer",
     "TransformerDecoder",
     "TiedEmbeddingTransformerDecoder",
diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py
index 6c0e65b811..714924c69d 100644
--- a/torchtune/modules/transformer.py
+++ b/torchtune/modules/transformer.py
@@ -518,7 +518,8 @@ def forward(

 @deprecated(
     msg="Please use torchtune.modules.TransformerDecoder instead. \
-If you need an example, see torchtune.models.qwen2._component_builders.py"
+If you need an example, see torchtune.models.qwen2._component_builders.py \
+and how to use torchtune.modules.TiedLinear for the output projection."
 )
 class TiedEmbeddingTransformerDecoder(nn.Module):
     """
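
For reference, the pattern the updated deprecation message points to looks roughly like this. It is a minimal sketch, not the actual qwen2 component builder; the `TiedLinear(tok_embeddings)` call and the `tok_embeddings`/`output` argument names of `TransformerDecoder` are assumptions based on the current torchtune API and may differ across versions.

```python
import torch.nn as nn

from torchtune.modules import TiedLinear

embed_dim, vocab_size = 2048, 32_000

# One embedding table shared between the input lookup and the output head.
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

# TiedLinear reuses tok_embeddings.weight as the final projection weight,
# so no separate vocab-sized output matrix is allocated (assumed torchtune API).
output_proj = TiedLinear(tok_embeddings)

# Sketch: pass the tied projection to the decoder instead of an nn.Linear, e.g.
# TransformerDecoder(..., tok_embeddings=tok_embeddings, output=output_proj)
```

This mirrors what `TiedEmbeddingTransformerDecoder` did implicitly: the output logits are computed with the embedding weights, so only one vocabulary-sized parameter matrix exists, but the tying is now expressed through an explicit module rather than a dedicated decoder class.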