
QLoRA with bias + Llama 3.2 Vision QLoRA configs #1726

Merged: 12 commits merged into pytorch:main on Oct 25, 2024

Conversation

@ebsmothers (Contributor) commented Oct 1, 2024

After opening pytorch/ao#979 on torchao, it was pointed out to me that I was overcomplicating things: we can just keep the bias in bf16, which is apparently a pretty standard thing to do (ref).

So this PR does exactly that: the bias simply stays in the higher precision for our FrozenNF4Linear, LoRALinear, and DoRALinear when we set quantize_base=True.
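
For illustration, here is a minimal sketch of the idea using torchao's NF4 helpers (the shapes and values are made up for the example; torchtune's modules wrap this same logic):

import torch
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

x = torch.randn(2, 16, dtype=torch.bfloat16)
weight = torch.randn(32, 16, dtype=torch.bfloat16)
bias = torch.randn(32, dtype=torch.bfloat16)

# The frozen base weight is quantized to NF4...
nf4_weight = to_nf4(weight)
# ...but the bias stays in bf16 and is simply added after the NF4 matmul.
out = linear_nf4(input=x, weight=nf4_weight) + bias
print(out.dtype)  # torch.bfloat16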

Test plan

Added test cases for LoRALinear and FrozenNF4Linear with bias=True.

Fun fact I discovered while writing the LoRALinear test: if x, weight, and bias are all bf16, then F.linear(x, weight, bias) != F.linear(x, weight) + bias (repro here). So I left that test case out.
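
(A sketch of that kind of repro, for reference; this is an illustrative snippet rather than the linked gist, and the exact mismatch will vary by backend and seed:)

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.bfloat16)
weight = torch.randn(32, 64, dtype=torch.bfloat16)
bias = torch.randn(32, dtype=torch.bfloat16)

fused = F.linear(x, weight, bias)     # bias handled inside the linear kernel
unfused = F.linear(x, weight) + bias  # bias added as a separate bf16 op

# With bf16 rounding, the two paths can differ by a few ULPs, so exact equality may fail.
print(torch.equal(fused, unfused), (fused - unfused).abs().max())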

E2E tests:

tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=qlora-32-vision \
metric_logger.name=single-device max_steps_per_epoch=500

and

tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/11B_qlora \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=qlora-32-vision \
metric_logger.name=distributed max_steps_per_epoch=500

Loss curves from both these runs (compared to analogous LoRA runs as a baseline):

[Screenshot (2024-10-23): loss curves for the single-device and distributed QLoRA runs vs. the LoRA baselines]

Also ran with DoRA applied to MLP and output layers as a sanity check:

tune run lora_finetune_single_device --config llama3_2_vision/11B_qlora_single_device \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=qlora-32-vision \
metric_logger.name=single-device model.apply_lora_to_mlp=True model.apply_lora_to_output=True model.use_dora=True max_steps_per_epoch=5 gradient_accumulation_steps=1
...

pytorch-bot bot commented Oct 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1726

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 09491e6 with merge base 17ba37d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 1, 2024
@ebsmothers ebsmothers marked this pull request as draft October 1, 2024 14:07
@joecummings joecummings mentioned this pull request Oct 15, 2024
34 tasks
@ebsmothers ebsmothers changed the title [wip] QLoRA with bias + Llama 3.2 Vision QLoRA configs QLoRA with bias + Llama 3.2 Vision QLoRA configs Oct 23, 2024
@ebsmothers ebsmothers marked this pull request as ready for review October 23, 2024 20:59
@joecummings (Contributor) left a comment:
Brilliant work. The kind of work they sing about in old minstrel songs.

One question.

@@ -81,7 +81,7 @@ enable_activation_offloading: False
dtype: bf16

# Logging
-output_dir: /tmp/full-llama3.2-vision-finetune
+output_dir: /tmp/lora-llama3.2-vision-finetune
Contributor:
whoops

@@ -59,9 +55,10 @@ def test_state_dict(self, dtype):
         assert isinstance(state_dict["weight"], NF4Tensor)

     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
-    def test_output_dtype(self, dtype):
+    @pytest.mark.parametrize("bias", [True, False])
Contributor:
What is the point of adding bias to this test? The dtype isn't changing and you're only checking the dtype?

Contributor Author:
Agreed it's pretty trivial, but I'd like to at least build FrozenNF4Linear with bias somewhere in our unit tests, and the overhead of this unit test is tiny.
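
(For context, a sketch of the kind of parametrized test being described, not the exact test added in the PR; it assumes the constructor keywords shown in the docstring quoted later in this review, and that extra kwargs such as dtype are forwarded to the underlying linear:)

import pytest
import torch
from torchtune.modules.low_precision import FrozenNF4Linear

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("bias", [True, False])
def test_output_dtype(dtype, bias):
    # Build the frozen NF4 linear both with and without a bias term.
    linear = FrozenNF4Linear(512, 512, dtype=dtype, bias=bias)
    x = torch.randn(2, 512, dtype=dtype)
    out = linear(x)
    # The NF4 matmul (plus the optional higher-precision bias) should preserve the input dtype.
    assert out.dtype == dtype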

use_bias=use_bias,
quantize_base=True,
)
# fixed_init_model(qlora_linear, dtype=torch.bfloat16)
Contributor:
Did you mean to comment this out?

Contributor Author:
Oh yeah oops, lemme update

@pbontrager (Contributor) left a comment:
Thank you, Evan, for adding this! Added mostly questions, but overall it looks good. Could you also run a test with LoRA applied to MLP and output as well, since those also have bias values?


Args:
    in_dim (int): input dimension
    out_dim (int): output dimension
    device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default device given by `torch.get_default_device()`.
    bias (bool): whether to include bias in the original linear layer. Default: False
Contributor:
nit: remove "original"

torchtune/modules/low_precision/nf4_linear.py (resolved)
torchtune/modules/low_precision/nf4_linear.py (resolved)
@@ -123,6 +119,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         if self._quantize_base:
             out = linear_nf4(input=x, weight=self.weight)
+            if self.use_bias:
Contributor:
Why do we use torchao linear_nf4 here and not our own?

Contributor Author:
Not sure I follow: we don't have our own linear_nf4. Even our FrozenNF4Linear is just a wrapper around their linear_nf4.

Contributor:
I meant, why do we not reuse the FrozenNF4Linear class here so we don't have to define this bias solution in multiple places?

Contributor Author:
Discussed a bit offline. It's a nice idea, but I'm going to punt on it for now.
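
(For the record, a rough sketch of the kind of reuse being discussed; the class names here are made up for illustration, and this is not what the PR ends up doing:)

import torch
import torch.nn as nn
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class TinyFrozenNF4Linear(nn.Module):
    # Frozen base linear: NF4-quantized weight, bias kept in bf16.
    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        weight = torch.randn(out_dim, in_dim, dtype=torch.bfloat16)
        self.weight = nn.Parameter(to_nf4(weight), requires_grad=False)
        self.bias = (
            nn.Parameter(torch.zeros(out_dim, dtype=torch.bfloat16), requires_grad=False)
            if bias
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = linear_nf4(input=x, weight=self.weight)
        # Bias handling lives in exactly one place.
        return out + self.bias if self.bias is not None else out


class TinyLoRALinear(nn.Module):
    # LoRA adapter that delegates the quantized base path (and its bias) to the frozen module.
    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0, bias: bool = False):
        super().__init__()
        self.base = TinyFrozenNF4Linear(in_dim, out_dim, bias=bias)
        self.lora_a = nn.Linear(in_dim, rank, bias=False, dtype=torch.bfloat16)
        self.lora_b = nn.Linear(rank, out_dim, bias=False, dtype=torch.bfloat16)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))


# Usage: TinyLoRALinear(512, 512, bias=True)(torch.randn(2, 512, dtype=torch.bfloat16))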


# Model arguments
model:
  _component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
Contributor:
Can you confirm that this is identical to lora except for this line? Whenever you do a merge you should re-check that assumption.

Contributor Author:
Yep, did this

self.weight.requires_grad_(False)
if self.bias is not None:
    self.bias.requires_grad_(False)
self.nf4_weight = to_nf4(self.weight)
Contributor:
nit: do we want this "self." here?

Contributor Author:
Yeah good point, maybe not strictly necessary

@codecov-commenter

Codecov Report

Attention: Patch coverage is 92.63158% with 7 lines in your changes missing coverage. Please review.

Project coverage is 67.89%. Comparing base (73aa126) to head (09491e6).
Report is 10 commits behind head on main.

Files with missing lines                                  Patch %   Lines
...tune/models/llama3_2_vision/_component_builders.py    46.15%    7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1726      +/-   ##
==========================================
- Coverage   70.25%   67.89%   -2.36%     
==========================================
  Files         309      308       -1     
  Lines       16285    16301      +16     
==========================================
- Hits        11441    11068     -373     
- Misses       4844     5233     +389     

☔ View full report in Codecov by Sentry.

Comment on lines +456 to +465
    encoder = Llama3VisionEncoder(clip=clip, projection_head=projection_head)

    if quantize_base:
        # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly
        # so as to not increase peak memory
        encoder._register_state_dict_hook(
            partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True)
        )

    return encoder
Contributor Author:
There are a bunch of miscellaneous linter changes in this file; this is the only one of substance.
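
(For readers unfamiliar with that hook: it rewrites the NF4 tensors in the returned state dict back to a higher-precision dtype at state_dict() time. A simplified sketch of such a hook, not torchtune's actual implementation, is:)

import torch
from torchao.dtypes.nf4tensor import NF4Tensor


def dequantize_nf4_state_dict_hook(module, state_dict, prefix, local_metadata,
                                    dtype=torch.bfloat16, offload_to_cpu=True):
    # Runs after state_dict() is populated: replace NF4 tensors with dequantized copies
    # so checkpoints are saved in a standard dtype, optionally moved to CPU so the extra
    # full-precision copies don't increase peak GPU memory.
    for key, value in state_dict.items():
        if isinstance(value, NF4Tensor):
            dequant = value.get_original_weight().to(dtype)
            state_dict[key] = dequant.cpu() if offload_to_cpu else dequant

As in the quoted snippet, the extra keyword arguments would be bound with functools.partial when registering the hook.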

new_key = prefix + "embedding.weight"
state_dict[new_key] = state_dict[key]
del state_dict[key]
if state_dict:
Contributor Author:
@pbontrager this is my hack to support DoRA. Let me know if you have any concerns.

Contributor:
I see why you do this. It seems like a general safety check that would be good to have in all of the load_state_dict hooks, since this case could come up any time strict=False is used. I'll approve this, but could you add this change to the other load hooks, either here or in a follow-up?
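
(A sketch of what that general safety check could look like in a load_state_dict hook; the hook and key names here are hypothetical, not torchtune's:)

def remap_legacy_keys_hook(module, state_dict, prefix, *args):
    # Remap a legacy key only if it is actually present.
    legacy_key = prefix + "tok_embedding.weight"  # hypothetical old name
    if legacy_key in state_dict:
        state_dict[prefix + "embedding.weight"] = state_dict.pop(legacy_key)
    # With strict=False the incoming state_dict may be partial or even empty
    # (e.g. when only DoRA magnitudes are being initialized), so any further
    # processing that assumes the expected keys exist should be guarded.
    if state_dict:
        pass  # remaining remapping / validation logic would go here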

@pbontrager pbontrager merged commit e030626 into pytorch:main Oct 25, 2024
17 checks passed