QLoRA with bias + Llama 3.2 Vision QLoRA configs #1726
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1726. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 09491e6 with merge base 17ba37d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Brilliant work. The kind of work they sing about in old minstrel songs.
One question.
```diff
@@ -81,7 +81,7 @@ enable_activation_offloading: False
 dtype: bf16

 # Logging
-output_dir: /tmp/full-llama3.2-vision-finetune
+output_dir: /tmp/lora-llama3.2-vision-finetune
```
whoops
```diff
@@ -59,9 +55,10 @@ def test_state_dict(self, dtype):
         assert isinstance(state_dict["weight"], NF4Tensor)

     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     def test_output_dtype(self, dtype):
+    @pytest.mark.parametrize("bias", [True, False])
```
What is the point of adding bias to this test? The dtype isn't changing and you're only checking the dtype?
Agreed it's pretty trivial, but I'd like to at least build FrozenNF4Linear with bias somewhere in our unit tests, and the overhead of this unit test is tiny.
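For context, here is a minimal sketch of the kind of parametrized case being discussed, assuming FrozenNF4Linear is importable from torchtune.modules.low_precision and accepts the bias/dtype keywords described in its docstring (both assumptions, not the PR's exact test):

```python
# Hedged sketch only: the import path and constructor keywords are assumptions.
import pytest
import torch

from torchtune.modules.low_precision import FrozenNF4Linear


@pytest.mark.parametrize("bias", [True, False])
def test_output_dtype_with_bias(bias):
    # Build the frozen NF4 linear with and without a bias and check that the
    # forward output stays in bf16 either way.
    layer = FrozenNF4Linear(in_dim=64, out_dim=128, bias=bias, dtype=torch.bfloat16)
    x = torch.randn(2, 64, dtype=torch.bfloat16)
    assert layer(x).dtype == torch.bfloat16
```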
```python
    use_bias=use_bias,
    quantize_base=True,
)
# fixed_init_model(qlora_linear, dtype=torch.bfloat16)
```
Did you mean to comment this out?
Oh yeah oops, lemme update
Thank you Evan for adding this! I added mostly questions, but overall it looks good. Could you also run a test with LoRA applied to MLP and output as well, since those also have bias values?
```python
    Args:
        in_dim (int): input dimension
        out_dim (int): output dimension
        device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default
            device given by `torch.get_default_device()`.
        bias (bool): whether to include bias in the original linear layer. Default: False
```
nit: remove "original"
```diff
@@ -123,6 +119,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         if self._quantize_base:
             out = linear_nf4(input=x, weight=self.weight)
+            if self.use_bias:
```
Why do we use torchao linear_nf4 here and not our own?
Not sure I follow... we don't have our own linear_nf4. Even our FrozenNF4Linear is just a wrapper around their linear_nf4.
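To make that concrete, here is a minimal sketch of the forward path under discussion, assuming torchao's to_nf4/linear_nf4 and a bias kept in bf16 (a sketch of the idea, not the PR's exact code):

```python
# Sketch: NF4 matmul via torchao, with the bias left in bf16 and added afterwards.
import torch
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

weight = torch.randn(128, 64, dtype=torch.bfloat16)
bias = torch.randn(128, dtype=torch.bfloat16)
x = torch.randn(2, 64, dtype=torch.bfloat16)

nf4_weight = to_nf4(weight)                   # 4-bit NF4 quantized copy of the weight
out = linear_nf4(input=x, weight=nf4_weight)  # torchao handles the quantized matmul
out = out + bias                              # bias stays in bf16, outside the NF4 path
```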
I meant, why do we not reuse the FrozenNF4Linear class here so we don't have to define this bias solution in multiple places?
Discussed a bit offline. It's a nice idea but going to punt on it for now
```yaml
# Model arguments
model:
  _component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
```
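As a side note, here is a rough sketch of how a `_component_` entry like this gets resolved at runtime, assuming torchtune's config.instantiate utility as used by the recipes; the lora_attn_modules value below is an illustrative assumption, and the shipped config passes additional builder arguments:

```python
# Hedged sketch: turning the YAML model section into a module via config.instantiate.
from omegaconf import OmegaConf

from torchtune import config

cfg = OmegaConf.create(
    {
        "model": {
            "_component_": "torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b",
            "lora_attn_modules": ["q_proj", "v_proj"],  # illustrative, not the shipped values
        }
    }
)
model = config.instantiate(cfg.model)  # builds the QLoRA Llama 3.2 Vision 11B model
```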
Can you confirm that this is identical to lora except for this line? Whenever you do a merge you should re-check that assumption.
Yep, did this
```python
        self.weight.requires_grad_(False)
        if self.bias is not None:
            self.bias.requires_grad_(False)
        self.nf4_weight = to_nf4(self.weight)
```
nit: do we want this "self." here?
Yeah good point, maybe not strictly necessary
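A hedged sketch of the alternative being floated here, i.e. registering the quantized tensor directly instead of keeping a separate self.nf4_weight attribute (illustrative only; the helper name is made up):

```python
# Illustrative sketch: freeze the layer and swap in the NF4 weight in place, rather
# than holding an intermediate self.nf4_weight attribute.
import torch
from torch import nn
from torchao.dtypes.nf4tensor import to_nf4


def freeze_and_quantize_(linear: nn.Linear) -> None:
    linear.weight.requires_grad_(False)
    if linear.bias is not None:
        linear.bias.requires_grad_(False)  # bias stays in its original (e.g. bf16) dtype
    # Register the quantized tensor as the module's weight; no extra attribute needed.
    linear.weight = nn.Parameter(to_nf4(linear.weight), requires_grad=False)


layer = nn.Linear(64, 128, bias=True, dtype=torch.bfloat16)
freeze_and_quantize_(layer)  # layer.weight is now an NF4Tensor parameter; bias untouched
```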
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #1726      +/-   ##
==========================================
- Coverage   70.25%   67.89%   -2.36%
==========================================
  Files         309      308       -1
  Lines       16285    16301      +16
==========================================
- Hits        11441    11068     -373
- Misses       4844     5233     +389
```

☔ View full report in Codecov by Sentry.
```python
    encoder = Llama3VisionEncoder(clip=clip, projection_head=projection_head)

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
        # so as to not increase peak memory
        encoder._register_state_dict_hook(
            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
        )

    return encoder
```
A bunch of miscellaneous linter changes in this file; this is the only one of substance.
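Conceptually, the hook registered in that diff rewrites NF4 weights at state_dict() time. A rough sketch of the behavior (an illustration of the idea, not torchtune's actual reparametrize_as_dtype_state_dict_post_hook):

```python
# Rough sketch of a reparametrize-as-dtype state_dict post-hook: dequantize NF4
# tensors back to a dense dtype (and optionally offload them to CPU) so saved
# checkpoints don't require NF4-aware loading. Illustrative only.
import torch
from torchao.dtypes.nf4tensor import NF4Tensor


def reparametrize_state_dict_sketch(
    module, state_dict, *args, dtype=torch.bfloat16, offload_to_cpu=True
):
    for key, value in state_dict.items():
        if isinstance(value, NF4Tensor):
            dense = value.to(dtype)  # dequantize to bf16 (assumed NF4Tensor behavior)
            if offload_to_cpu:
                dense = dense.cpu()  # avoid holding both copies on the GPU
            state_dict[key] = dense
    return state_dict
```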
```python
new_key = prefix + "embedding.weight"
state_dict[new_key] = state_dict[key]
del state_dict[key]
if state_dict:
```
@pbontrager this is my hack to support DoRA. Lmk if any concerns
I see why you do this. This seems like a general safety check that would be good to have in all of the load_state_dict hooks, since this case could come up anytime strict=False is used. I'll approve this, but could you add this change here or in a follow-up in the other load hooks?
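For that follow-up, here is a sketch of the kind of guard being asked for (names are illustrative, not the exact torchtune hooks):

```python
# Illustrative sketch of the "empty state_dict" guard: with strict=False loads (as in
# the DoRA flow), a module's slice of the incoming state_dict may be empty, so key
# remapping should bail out early instead of assuming the keys exist.
def load_state_dict_hook_sketch(module, state_dict, prefix, *args, **kwargs):
    old_key = prefix + "weight"
    if old_key in state_dict:
        state_dict[prefix + "embedding.weight"] = state_dict.pop(old_key)
    if not state_dict:
        # Nothing (else) was passed for this module; skip further fixups.
        return
    # ... remaining key remapping for a non-empty state_dict ...
```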
After opening pytorch/ao#979 on torchao, it was pointed out to me that I was overcomplicating things... we can just keep the bias in bf16, which is apparently a pretty standard thing to do (ref).

So this PR does exactly that: it just lets the bias stay in the higher precision for our FrozenNF4Linear, LoRALinear, and DoRALinear when we set quantize_base=True.

Test plan
Added test cases for LoRALinear and FrozenNF4Linear with bias=True.
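Roughly, the new LoRALinear case exercises a construction like the sketch below (argument names follow the diff above; the sizes, dtype handling, and defaults are assumptions):

```python
# Hedged sketch of the added LoRALinear case: bias enabled together with quantize_base.
import torch

from torchtune.modules.peft import LoRALinear

torch.set_default_dtype(torch.bfloat16)  # base weight, bias, and LoRA params in bf16
qlora_linear = LoRALinear(
    in_dim=64,
    out_dim=128,
    rank=4,
    alpha=1.0,
    use_bias=True,
    quantize_base=True,
)
x = torch.randn(2, 64, dtype=torch.bfloat16)
assert qlora_linear(x).dtype == torch.bfloat16
```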
Fun fact I discovered while writing the LoRALinear test: if x, weight, and bias are all bf16, then F.linear(x, weight, bias) != F.linear(x, weight) + bias (repro here). So I left that test case out.
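For the curious, here is a hedged repro of that mismatch (not the linked repro; the exact deltas depend on hardware and backend):

```python
# Sketch: in bf16, the fused-bias path of F.linear can round differently from adding
# the bias as a separate op, so exact equality between the two is not guaranteed.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(32, 64, dtype=torch.bfloat16)
weight = torch.randn(128, 64, dtype=torch.bfloat16)
bias = torch.randn(128, dtype=torch.bfloat16)

fused = F.linear(x, weight, bias)     # bias folded into the matmul kernel
unfused = F.linear(x, weight) + bias  # bias added as a separate bf16 elementwise op

print(torch.equal(fused, unfused))    # may be False
print((fused - unfused).abs().max())  # rounding-level difference, if any
```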
E2E tests:

Loss curves from both these runs (compared to analogous LoRA runs as a baseline):

Also ran with DoRA applied to MLP and output layers as a sanity check: