1/n - remove TiedEmbeddingTransformerDecoder from qwen #1547
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TiedLinear:
    """
    A tied linear layer, without bias, that shares the same weight as another linear layer.
    This is useful for models that use tied weights, such as qwen and gemma.
    It takes an nn.Module as input, instead of the module's weight, so that it
    works with FSDP. Otherwise, the memory reference is lost after FSDP is applied.

    Args:
        tied_module (nn.Module): The module whose weight is shared. Only
            the weight is used. The bias is ignored.
    Raises:
        AttributeError: If the provided module does not have an attribute 'weight'.
    """

    def __init__(self, tied_module: nn.Module):
        self.tied_module = tied_module
        if not hasattr(tied_module, "weight"):
            raise AttributeError(
                "Provided module does not have attribute 'weight'. Please check your tied_module."
            )

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.tied_module.weight)
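
A minimal usage sketch (not part of the diff; the vocabulary and embedding sizes are made up for illustration) showing the shared embedding weight doubling as the output projection:

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
embed = nn.Embedding(num_embeddings=1000, embedding_dim=64)
output_proj = TiedLinear(embed)

x = torch.randn(2, 8, 64)   # (batch, seq_len, embed_dim)
logits = output_proj(x)     # (batch, seq_len, num_embeddings) == (2, 8, 1000)

Because TiedLinear keeps a reference to the module rather than its weight tensor, the projection always uses whatever parameter the embedding currently holds, which is what makes it safe to wrap with FSDP.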
@@ -4,14 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
-from typing import Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from torchtune.modules import MultiHeadAttention

Review comment: PS: I forgot to update the typehint of the TransformerDecoder, saying that output can now be a callable. To avoid rerunning tests, this will come in a following gemma PR.

from torchtune.modules.attention_utils import _MaskType
from torchtune.utils.logging import deprecated


class TransformerSelfAttentionLayer(nn.Module):
@@ -295,7 +295,7 @@ class TransformerDecoder(nn.Module):
            to setup the :func:`~torchtune.modules.KVCache`
        norm (nn.Module): Callable that applies normalization to the output of the decoder,
            before final MLP.
-       output (nn.Linear): Callable that applies a linear transformation to the output of
+       output (Union[nn.Linear, Callable]): Callable that applies a linear transformation to the output of
            the decoder.
        num_layers (Optional[int]): Number of Transformer Decoder layers, only define when
            layers is not a list.

@@ -320,7 +320,7 @@ def __init__(
        num_heads: int,
        head_dim: int,
        norm: nn.Module,
-       output: nn.Linear,
+       output: Union[nn.Linear, Callable],
        num_layers: Optional[int] = None,
        output_hidden_states: Optional[List[int]] = None,
    ) -> None:
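
A hedged aside (not part of the diff; names and sizes are illustrative): after this change, `output` only needs to be a callable that maps hidden states to logits, so weight tying can be expressed outside the decoder.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, for illustration only.
embed_dim, vocab_size = 64, 1000
embed = nn.Embedding(vocab_size, embed_dim)

output_as_linear = nn.Linear(embed_dim, vocab_size, bias=False)  # the original, still-valid option
# Any callable works too; resolving embed.weight at call time is what TiedLinear does internally.
output_as_callable = lambda h: F.linear(h, embed.weight)

h = torch.randn(2, 8, embed_dim)
assert output_as_linear(h).shape == output_as_callable(h).shape == (2, 8, vocab_size)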

@@ -516,6 +516,10 @@ def forward(
        return output

Review comment: Tell them to use TransformerDecoder WITH TiedLinear.

@deprecated(
    msg="Please use torchtune.modules.TransformerDecoder instead. \
    If you need an example, see torchtune.models.qwen2._component_builders.py"
)
class TiedEmbeddingTransformerDecoder(nn.Module):
    """
    Transformer Decoder with tied embedding weight. A key difference between
Review comment: formatting weird?
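
A minimal migration sketch of what the deprecation message points to (not from the PR; the layer-building helper and sizes are hypothetical, and the full TransformerDecoder constructor signature is assumed from torchtune rather than shown in this diff): tie the output projection to the token embedding via TiedLinear instead of using TiedEmbeddingTransformerDecoder.

from torch import nn
from torchtune.modules import TransformerDecoder, TiedLinear  # assumed import path for TiedLinear

# Hypothetical sizes; a real builder lives in
# torchtune.models.qwen2._component_builders.py, as the deprecation message notes.
vocab_size, embed_dim, num_heads, max_seq_len = 1000, 64, 8, 2048
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

model = TransformerDecoder(
    tok_embeddings=tok_embeddings,
    layers=build_decoder_layers(),        # hypothetical helper returning the decoder layers
    max_seq_len=max_seq_len,
    num_heads=num_heads,
    head_dim=embed_dim // num_heads,
    norm=nn.LayerNorm(embed_dim),         # stand-in; the qwen2 builder uses its own norm
    output=TiedLinear(tok_embeddings),    # tied output projection replaces TiedEmbeddingTransformerDecoder
)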