-
Notifications
You must be signed in to change notification settings - Fork 482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add gemma2 variants #1835
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1835
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 53eed40 with merge base 57ab583 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for adding this! Just took a quick and very non-exhaustive first pass to leave a few comments, will get back to it with a full review later today.
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Gemma2Attention(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we support flex attention, which support soft capping, would it make sense to just force gemma2 users to use flex attention instead of implementing this module?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flex Attention is only supported on A100 or better, right? I don't think we can make the assumption that our users will have that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello everyone,
I just pushed a new commit which includes all changes discussed with @ebsmothers.
I also implemented a flex attention version but I could not make it work properly.
The default implementation (not using FlexAttention) seems to be working (I only launched the single lora pipeline, please see the attached logs log_gemma2-2b-single-lora_1729498141.txt).
I would appreciate some help on the FlexAttention implementation. Here is why I am struggling.
If I run the following code on my A6000 GPU with torch 2.5:
import torch
from torch.nn.attention.flex_attention import (
create_block_mask,
flex_attention)
WINDOW_SIZE=None #None
CAPPING=50.
SCALE=12.
def get_gemma2_flex_score_mask(sliding_window_size, softcapping, query_pre_attn_scalar):
def sliding_window_causal_mask(b, h, q_idx, kv_idx):
"""Causal mask and sliding window as proposed here:
https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
"""
causal_mask = q_idx >= kv_idx
if sliding_window_size is None:
# if no sliding window return causal mask
return causal_mask
else:
windowed_mask = q_idx - kv_idx <= sliding_window_size
return causal_mask & windowed_mask
def soft_capping_with_scaling(score, b, h, q_idx, kv_idx):
if query_pre_attn_scalar is None:
# usual scaling included in FlexAttention
score = score / softcapping
score = torch.tanh(score) #tanh_approx(score)
return score * softcapping
else:
score = score / softcapping * query_pre_attn_scalar**-0.5
score = torch.tanh(score) #tanh_approx(score)
return score * softcapping
return sliding_window_causal_mask, soft_capping_with_scaling
# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)
B=4
H=8
S=117
D=256 #256
mask_mod, score_mod = get_gemma2_flex_score_mask(WINDOW_SIZE, CAPPING, SCALE)
query = torch.randn(
B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
key = torch.randn(
B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
value = torch.randn(
B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
gradOut = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
block_mask = create_block_mask(mask_mod=mask_mod,
B=1,
H=1,
Q_LEN=S,
KV_LEN=S,
device=query.device)
out = flex_attention(
query, key, value, score_mod=score_mod, block_mask=block_mask
)
print(out.shape)
The code runs fine if I don't compile the flex attention by commenting flex_attention = torch.compile(flex_attention, dynamic=False)
but it raises this error otherwise:
BackendCompilerFailed: backend='inductor' raised: OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or
num_stages may help.
So I disabled compilation and the code seems to be running but very very slowly (48s per iteration vs 1-2s on non flex implementation).
Maybe you could help me understand what is going on ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tagging @RdoubleA and @felipemello1 for their thoughts.
Just checking: which size Gemma-2 model are you testing with?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logs I shared are from gemma2 2B, the code snippet is independent of the gemma architecture it's just a toy example.
I am currently running the qlora single device pipeline with 9B (without flex attention), I'll share the logs tomorrow (I'll push the changes to recipe as there are typos on the output path etc).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the given kernel options, flex attention can be compiled and the code runs (9b lora single device training). However, the code is terribly slow (29 tokens per second) and the loss turns to nan after one batch:
Step 1 | loss:87.6104507446289 lr:2.0000000000000003e-06 tokens_per_second_per_gpu:21.504354449638843
Step 2 | loss:nan lr:4.000000000000001e-06 tokens_per_second_per_gpu:29.156293709460176
I don't understand what I am doing wrong, the only obvious optimisation I see is to create one block mask for every layer while I am currently recreating the same block mask for every layer (line 593 in gemma2/_attention.py
). Nevertheless, I do not think that this is the current bottleneck.
Wouldn't it be better to go with the simpler implementation for now and switch to FlexAttention when it will work on more GPUs? or at least leave the choice of computation to the final user and default to the classical implementation ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I see what you're saying.. I repro'd this on my end too so it is not a function of any custom kernel configs you're using. Let me look into this a bit more but in the meantime it seems like we shouldn't enable the flex version until we figure this out
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated the code to keep the flex attention implementation but disable it for now, until we have found a solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey sorry just catching up:
So 2.5 should not require multiple of 128 for sequence length. It is unfortunately pretty common for consumer gpus to hit the SharedMemory issue. I have a pr: pytorch/pytorch#137959 to drop default block sizes but still need to debug the failing test.
For being slow, it is expected that the tanh instruction is very slow compared to the inline assembly variant: https://github.com/pytorch-labs/attention-gym/blob/36f8bd5ded5b3469f7892099590bb2405cc8f744/attn_gym/mods/softcapping.py#L92.
It is actually quite hard generically to know what what block sizes should be used since the amount of shared memory depends on the captured buffers. I am working on a better solution but that is going to take some time unfortunately
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you all figured out a solution to the out of resource: shared memory? Seems like any large hidden dim >=128 causes issues for me.
fd79f85
to
e999572
Compare
e999572
to
6f89920
Compare
I have pushed changes to the recipes for 9b and 27b (typos in folders' name). |
# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN> | ||
# | ||
# To launch on 4 devices, run the following command from root: | ||
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did some quick math, I guess this will take at least 216GB total memory (54GB params + 54GB gradients + 108GB optimizer states for AdamW) , which means to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer + optimizer in backward to get us down to a more reasonable peak VRAM here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does 8bit work with distributed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah duh.. there may be some issues with bitsandbytes optimizers on that front. I just tried out ao low-precision optimizers and it seems to work (though haven't resumed from intermediate checkpoint). Also there may be a compile dep there. Anyways if it's too much hassle we can consider it separately, don't wanna increase the scope of this already substantial PR more than necessary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What should I do here? Change something or expect users to change parameters according to their hardware ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR
|
||
checkpointer: | ||
_component_: torchtune.training.FullModelHFCheckpointer | ||
checkpoint_dir: /tmp/gemma-2b/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
checkpoint_dir: /tmp/gemma-2b/ | |
checkpoint_dir: /tmp/gemma-2-2b/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
if query_pre_attn_scalar is not None: | ||
self.scaling = query_pre_attn_scalar**-0.5 | ||
else: | ||
self.scaling = self.head_dim**-0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you need to add self.cache_enabled=False
here (then set it to True at the end of setup_cache
), otherwise this will error out. But this is kind of a gotcha, it's not obvious that you need this. cc @SalmanMohammadi we should think about how to make this more obvious to someone adding their own attention layer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I added a comment to indicate why it's in the init (maybe @Optimox forked before then?)
# this flag indicates whether to update the kv-cache during forward
# passes. when disabled, we can have the cache setup but still
# perform normal forward passes
self.cache_enabled = False
Could we be clearer here? I agree we could use with a comment in setup_caches
explaining that you actually need to do this if you'd like to use the caches you've just setup.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes I think I forked before this change, will make the change tomorrow thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
k = self.k_norm(k) | ||
|
||
# Update key-value cache | ||
if self.kv_cache is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if self.kv_cache is not None: | |
if self.kv_cache is not None and self.cache_enabled: |
should complement the cache enabled stuff earlier to match the other attention module
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss | ||
|
||
# Fine-tuning arguments | ||
batch_size: 8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we confident this'll fit on a single device?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed batch size to 2 and accumulation to 8. What is the expected GPU? Is there a CI running everything? Otherwise I guess each user should be responsible to play with the batch to get something suitable for his GPU no ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally we try ship configs which we know will work on some common hardware configuration (see examples here https://github.com/pytorch/torchtune?tab=readme-ov-file#memory-and-training-speed), so users can maintain the expectation that they can get started without any painful OOMs. Then they are free to play with the configs. We should make sure this config works with e.g. 1xA1000 - let me know if you need a hand here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SalmanMohammadi I do not have easy access to a A100, would appreciate if someone could run the code for the 27B params model and let me know what batch size I should set.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a quick look when we're ready to land. We can also reasonably mirror the batch size from the config of another similarly sized model already in the codebase.
# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN> | ||
# | ||
# To launch on 4 devices, run the following command from root: | ||
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it's just me but when I try to run these distributed recipes I am hitting AssertionError: FSDP requires named DeviceMesh dims for ND parallelism
. It looks to me like we are actually entering _init_sharded_param
with a DTensor (see here), which does not happen with our other recipes. Need to figure out why this would be happening
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I think I cracked the case. See here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Big mistake, thank you for catching that!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logs after this fix look much better than previously for the 9b single lora pipeline!
log_gemma2-2b-single-lora_1729937021.txt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ebsmothers aren't the losses too low? could it be because of the (non) causal sliding window attention problem ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I missed this comment before. I am gonna run some of your configs on my end now so will get back to you
""" | ||
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) | ||
|
||
mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be inside the for loop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
path: /tmp/gemma-2-27b/tokenizer.model | ||
|
||
# Dataset | ||
dataset: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry to potentially be a pain in the ass here. We have parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have packed: False
in dataset
, and log_peak_memory_stats: True
and compile: False
below, for every one of our configs.
Would it be annoying to ask if we could update these in the same way while we're here, please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done I have updated all the configs to match the other PR!
@@ -27,6 +27,4 @@ | |||
"lora_gemma_7b", | |||
"qlora_gemma_2b", | |||
"qlora_gemma_7b", | |||
"gemma_hf_to_tune", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch : )
dd4cf33
to
54a237c
Compare
flex_causal_sliding_window, | ||
flex_tanh_soft_capping_with_scaling, | ||
) | ||
logger = logging.getLogger(__name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Why is this style of logger getting proliferated? We should be calling get_logger
from our utils.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's just a copy paste on my side, let me know if you want me to change that on this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just change to torchtune.utils.get_logger
, but no strong preference here. Either way we should clean up other usages in a follow-up
Hi @Optimox sorry for the delay here. Given that the flex attention version is still not working properly, how do you feel about pulling it out of this PR? Then we can revisit in a follow-up. For context we are going to be cutting a release soon (targeting code freeze tomorrow) so don't want to block getting this in on something that we can address in a follow-up. Let me know if this makes sense to you. |
@ebsmothers yes no problem! What is the best way of handling this? Adding a new commit deleting the flex attention part of this branch ? Or creating a new PR without the flex attention part? |
@Optimox honestly whatever is easiest for you. I imagine just a commit deleting the flex code would be simplest, but feel free to do whatever makes sense to you! |
@ebsmothers I have removed the flex attention implementation from the code, let me know if there are still other changes to make! |
Gemma 2 and Gemma original implementations share different normalization but with | ||
the same name, so it is mandatory to differentiate their state dict in order to map | ||
correctly the different weights. | ||
They are essentially the same except for "model.layers.{}.post_attention_layernorm.weight" key. | ||
See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for documenting this
sliding_mask = torch.triu( | ||
all_ones, -1 * self.sliding_window_size + 1 | ||
) * torch.tril(all_ones, self.sliding_window_size - 1) | ||
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) | ||
|
||
if self.softcapping is not None: | ||
output = output / self.softcapping | ||
output = torch.tanh(output) | ||
output = output * self.softcapping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add code comments explaining sliding window and the softcapping? (Also one for the magic value in the torch.where
line wouldn't hurt)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
About this part of the code, I actually followed blindly the official pytoch implementation from Google here
I am not sure why they used this magic number instead of -torch.inf
...
About the sliding_mask
I am now worried that something is wrong here because of the way I defined the causal mask...
s_x = 10
sliding_window_size = 5
mask = torch.tril(
torch.ones(
size=(s_x, s_x),
dtype=torch.bool,
)
)
print(mask)
all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
all_ones, -1 * sliding_window_size + 1
) * torch.tril(all_ones, sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
print(mask)
The final mask here does not seem to be causal anymore, and sliding future tokens are now accessible somehow...
Something like the following would seem better to me but is there a difference in the way masks are defined in gemma2 official code and torchtune?
s_x = 10
sliding_window_size = 5
mask = torch.tril(
torch.ones(
size=(s_x, s_x),
dtype=torch.bool,
)
)
mask = torch.where(mask==0, -torch.inf, 1)
print(mask)
all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
all_ones, -1 * sliding_window_size + 1
) * torch.tril(all_ones, sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -torch.inf)
print(mask)
This seems concerning... what do you think ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Optimox yeah good catch, that first mask definitely does not look right. I'm not that familiar with the official implementation, but looks like they are treating it as an additive mask here. In that case there should definitely not be anything that's not -torch.inf (or a very large negative number) above the diagonal.
This seems like a bug to me, maybe you can open an issue on the gemma_pytorch repo to confirm? Your second implementation looks correct to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry to add another kitchen to the cook here so late. I thought I'd share some of my findings from the last time I worked on masking + SDPA. I'd reccomend checking this PR out pytorch/pytorch#133882 and the linked issues.
TLDR; this is the correct approach to use the attention mask like a "bias" by adding a very large negative to the q/k tensors. However, using -inf
as this negative number has been shown to produce NaN gradients for some rare corner cases (e.g. when an entire row is masked out). In Transformers, the approach is to use something like torch.finfo(dtype).min
[1] (which is maybe where the original magic number is coming from?)
import torch
x = torch.Tensor([[[float("-inf"), float("-inf"), float("-inf")]]])
softmax = torch.nn.Softmax(dim=-1)
softmax(x)
# tensor([[[nan, nan, nan]]])
dtype = torch.bfloat16
min_value = torch.finfo(dtype).min
# -3.3895313892515355e+38 - on MPS, this will vary depending on the hardware you're using
x = torch.Tensor([[[min_value, min_value, min_value]]])
softmax = torch.nn.Softmax(dim=-1)
softmax(x)
# tensor([[[0.3340, 0.3340, 0.3340]]])
Aside, as of torch 2.5 this is handled internally slightly differently. -inf
is used in the mask, softmax is performed, but then any rows in the original tensor which have entirely masked out rows are explicitly set to zero.
[1] Transformers follows this approach for their Gemma2 implementation. However, this apparently still causes issues for some dtypes so it's also been suggested to use torch.finfo(dtype).min / 2
- see huggingface/transformers#32390, or to just attend to all tokens in a row containing only padding tokens equally (huggingface/transformers@e22d913), but I'm not 100% about how this works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait so I am a bit confused here (this is also in reference to the issue opened by @Optimox on the gemma_pytorch repo). There are two separate questions here, right?
(1) Should values above the diagonal be unmasked?
(2) What is the right masked value for an additive mask?
I think both the gemma_pytorch discussion and @SalmanMohammadi's comments address (2), and that I am not so worried about. But I think (1) is more fundamental, and I can directly copy-paste the snippet from gemma_pytorch to get this:
import torch
s_x = 5
sliding_window_size = 3
mask = torch.tril(
torch.ones(
size=(s_x, s_x),
dtype=torch.bool,
)
)
all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
all_ones, -1 * sliding_window_size + 1
) * torch.tril(all_ones, sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
print(mask)
...
tensor([[ 1.0000e+00, 0.0000e+00, 0.0000e+00, -2.3820e+38, -2.3820e+38],
[ 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, -2.3820e+38],
[ 1.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
[-2.3820e+38, 1.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00],
[-2.3820e+38, -2.3820e+38, 1.0000e+00, 1.0000e+00, 1.0000e+00]])
This demonstrates that there are values above the diagonal that are unmasked, no? (I will also reopen the issue on there just to confirm I am not missing something here)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I understood is that there are two ways of defining masks (in bias mode):
- 0 for reachable, -inf (or very large negative number) for unreachable -> this is the expected input mask for gemma_pytorch implementation
- bolean mask (True for reachable, False otherwise) + torch.nn.functional.scaled_dot_product_attention which internally switches from boolean to the first implementation
if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
So here, since all previous models from torchtune used torch sdpa there was a mismatch between both implementation, that is why I added this conversion in the latest changes. This conversion should also work for block mask defined as boolean mask currently in torchtune code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I also caught up with @SalmanMohammadi offline about this. It seems like the initial definition of mask
in the code snippet I shared in my last comment did not match what they do in gemma_pytorch, so that was a misunderstanding on my part. I think your approach makes sense, will look at the code more closely to confirm the BlockMask case (though I guess if we're not supporting packed or flex yet it doesn't matter?)
q.mul_(self.scaling) | ||
output = torch.matmul(q, k.transpose(2, 3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add shape comments here too
x: torch.Tensor, | ||
y: Optional[torch.Tensor] = None, | ||
*, | ||
mask: Optional[_MaskType] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you run any of the configs with packed=True
(i.e. when mask is a BlockMask
)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't run a full training but I've checked that the code does not through any error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See my latest comment: packed=True won't work at the moment!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few more comments and questions but overall this is looking great!
@Optimox Any chance you can get to these last comments today? |
Co-authored-by: ebsmothers <[email protected]>
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1835 +/- ##
==========================================
+ Coverage 67.05% 67.44% +0.38%
==========================================
Files 305 316 +11
Lines 15937 17143 +1206
==========================================
+ Hits 10687 11562 +875
- Misses 5250 5581 +331 ☔ View full report in Codecov by Sentry. |
I have pushed some minimal changes which take into account the fact that we are not using spda with gemma2 so the masks must be converted from boolean to True -> 0s and False -> -inf as discussed here. There is one last identified concern on my side, @ebsmothers I think I answered somewhere that packed dataset was working but actually it is not. If I manually disable flexattention (I don't know how to disable it locally or automatically) then I had a broadcasting issue which has been solved in my latest commit. So currently the code won't work for packed dataset for torch >= 2.5. |
recipes/configs/gemma2/27B_full.yaml
Outdated
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.gemma2.gemma_27b |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_component_: torchtune.models.gemma2.gemma_27b | |
_component_: torchtune.models.gemma2.gemma2_27b |
|
||
# Model Arguments | ||
model: | ||
_component_: torchtune.models.gemma2.qlora_gemma_27b |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_component_: torchtune.models.gemma2.qlora_gemma_27b | |
_component_: torchtune.models.gemma2.qlora_gemma2_27b |
checkpoint_dir: /tmp/gemma-2-27b/ | ||
checkpoint_files: | ||
filename_format: model-{}-of-{}.safetensors | ||
max_filename: 00024 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think all usages of max_filename should look like this instead
max_filename: 00024 | |
max_filename: "00024" |
@Optimox thanks for your patience in the review process here. Is this comment about unusually low losses still a concern? I ran a few of your configs on my end and the loss does increase pretty dramatically (though I also don't have a baseline). For reference here are some loss curves: Otherwise regarding the fact that packed is not yet supported, it seems like we should raise an error somewhere if that's the case. Maybe inside the attention if we receive a |
@ebsmothers yes the comment was before the fix of the slinding window attention mask. I have added a NotImplementedError for the BlockMasks and made the small fixes in the yaml files you pointed. Let me know if you think some other changes should be made! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this great new feature! And thanks for your patience and diligence throughout the review process. Very excited that we're now able to support Gemma 2 in torchtune.
Context
What is the purpose of this PR? Is it to
This is related to adding gemma2 support #1813
Changelog
What are the changes made in this PR?
*
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
)pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example