-
Notifications
You must be signed in to change notification settings - Fork 482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multimodal collater with interleaved image, cross-attention mask padding #1156
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1156
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 85dbb95 with merge base 8451b0d (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1156 +/- ##
==========================================
+ Coverage 71.21% 73.42% +2.20%
==========================================
Files 287 287
Lines 14058 14128 +70
==========================================
+ Hits 10011 10373 +362
+ Misses 4047 3755 -292 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really good way to handle a complicated topic. I think we should name this padded_collate_vision_text for now and pretend that it's general until it's not. My suspicion is that we'll have to move this into the model folder in the future as it's so model specific. After you address the comments I left, I'll test this for you on my script.
torchtune/data/_collate.py
Outdated
... "encoder_mask": [torch.ones(2, 5 * 4)], | ||
... }, | ||
... ] | ||
>>> model_inputs = padded_collate_vision_text(batch=batch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This name doesn't match the current name. I actually prefer padded_collate_vision_text as it's more straight forward and we can either generalize this function or split and rename as we get more vision_text models in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+ 1
torchtune/data/_collate.py
Outdated
[8, 9, -100, -100]]) | ||
>>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w) | ||
torch.Size([2, 2, 4, 1, 1, 1]) | ||
>>> print(model_inputs["encoder_mask"].shape) # (bsz, max_num_images, max_num_tiles, tokens_per_tile * max_num_tiles) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should actually be [2, 4, 40] since cross attention is text vs image sequence and the image sequence is num_imagesnum_tilestokens_per_tile.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but what about batch size?
torch.Size([2, 2, 4, 1, 1, 1]) | ||
>>> print(model_inputs["encoder_mask"].shape) # (bsz, max_num_images, max_num_tiles, tokens_per_tile * max_num_tiles) | ||
torch.Size([2, 2, 4, 20]) | ||
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if this should be [2, 2, 2] or [2, 4]. @felipemello1 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aspect_ratio should be (bsz, max_num_images, 2), and then in the clip we reshape:
aspect_ratio = aspect_ratio.reshape(bsz_and_n_imgs, 2)
collated_text = padded_collate_sft(text_only, padding_idx, ignore_idx) | ||
max_seq_len = collated_text["tokens"].shape[-1] | ||
|
||
# TODO: Figure out how to make this more efficient or vectorized. Setting |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
didnt think too much about it, but maybe:
- do a first pass to check the max of each dimension.
- create a tensor with all zeros. Pre allocating should simplify all the padding.
- Add the input to the tensor correct line: eg. tensor[0] += sample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pre-allocating would definitely simplify the code. I would still need to loop through each individual image though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will leave this as a follow-up though in the interest of time
I actually think we should go the other way and keep this overspecified until we have a concrete use case to generalize it, I'd rather be overspecific than mislead users into thinking this can be used for any multimodal model |
- "tokens": List[int] of length text_seq_len, varies across samples | ||
- "labels": List[int] of length text_seq_len, varies across samples | ||
- "encoder_input": Dict[str, List[torch.Tensor]] | ||
- "images": List[torch.Tensor], each with shape (n_tiles, c, h, w) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you say somewhere c, h, w
= channel, height, width
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've tested this and everything seems to be working accurately. The only issue right now is that this can't be used for inference. We need to expose padded_direction for text and not expect "labels" when padded_direction=="left". Also, I'd like to propose "padded_collate_tiled_images_and_mask".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
Context
Add the batch collater for multimodal image + text datasets. The collater supports the following:
Inputs must be samples from the multimodal dataset post tiling, post transform:
It performs the following actions:
(1) Pad text sequence and encoder mask to the longest sequence length in the batch
(2) Pad image tensors in the tile dimension with zeros to the largest number of tiles in the batch
(3) Add empty images of zeros to samples up to max number of images in the batch
(4) Pad aspect ratios with (1,1) for all added padding images
Feedback requested:
Changelog
Test plan
pytest tests/torchtune/utils/test_collate.py
Docs