Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multimodal collater with interleaved image, cross-attention mask padding #1156

Merged
merged 21 commits into from
Sep 11, 2024

Conversation

RdoubleA
Copy link
Contributor

@RdoubleA RdoubleA commented Jul 9, 2024

Context

Add the batch collater for multimodal image + text datasets. The collater supports the following:

  • Tiled images
  • Multiple images per sample
  • Cross-attention masks

Inputs must be samples from the multimodal dataset post tiling, post transform:

  • "tokens": List[int] of length text_seq_len, varies across samples
  • "labels": List[int] of length text_seq_len, varies across samples
  • "images": List[Tensor], each with shape (n_tiles, c, h, w)
  • "encoder_mask": List[Tensor], each with shape (text_seq_len, image_seq_len)
  • "aspect_ratio": List[Tensor], each with shape (h_ratio, w_ratio)

It performs the following actions:
(1) Pad text sequence and encoder mask to the longest sequence length in the batch
(2) Pad image tensors in the tile dimension with zeros to the largest number of tiles in the batch
(3) Add empty images of zeros to samples up to max number of images in the batch
(4) Pad aspect ratios with (1,1) for all added padding images

Feedback requested:

  • The name is a bit verbose, but padded_collate_multimodal is too vague as well
  • We are padding numerous dimensions here, requiring nested for loops with runtime O(total num of images in batch) and running it twice. Would be great to see if we can simplify/optimize/vectorize it further

Changelog

  • Added the collate function + unit test

Test plan

pytest tests/torchtune/utils/test_collate.py

Docs

image image

Copy link

pytorch-bot bot commented Jul 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1156

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 85dbb95 with merge base 8451b0d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 9, 2024
@codecov-commenter
Copy link

codecov-commenter commented Sep 4, 2024

Codecov Report

Attention: Patch coverage is 98.59155% with 1 line in your changes missing coverage. Please review.

Project coverage is 73.42%. Comparing base (8451b0d) to head (6a0b462).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/data/_collate.py 97.22% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1156      +/-   ##
==========================================
+ Coverage   71.21%   73.42%   +2.20%     
==========================================
  Files         287      287              
  Lines       14058    14128      +70     
==========================================
+ Hits        10011    10373     +362     
+ Misses       4047     3755     -292     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really good way to handle a complicated topic. I think we should name this padded_collate_vision_text for now and pretend that it's general until it's not. My suspicion is that we'll have to move this into the model folder in the future as it's so model specific. After you address the comments I left, I'll test this for you on my script.

torchtune/data/_collate.py Outdated Show resolved Hide resolved
torchtune/data/_collate.py Outdated Show resolved Hide resolved
... "encoder_mask": [torch.ones(2, 5 * 4)],
... },
... ]
>>> model_inputs = padded_collate_vision_text(batch=batch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name doesn't match the current name. I actually prefer padded_collate_vision_text as it's more straight forward and we can either generalize this function or split and rename as we get more vision_text models in the future.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+ 1

[8, 9, -100, -100]])
>>> print(model_inputs["encoder_input"]["images"].shape) # (bsz, max_num_images, max_num_tiles, c, h, w)
torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape) # (bsz, max_num_images, max_num_tiles, tokens_per_tile * max_num_tiles)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should actually be [2, 4, 40] since cross attention is text vs image sequence and the image sequence is num_imagesnum_tilestokens_per_tile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but what about batch size?

torch.Size([2, 2, 4, 1, 1, 1])
>>> print(model_inputs["encoder_mask"].shape) # (bsz, max_num_images, max_num_tiles, tokens_per_tile * max_num_tiles)
torch.Size([2, 2, 4, 20])
>>> print(model_inputs["encoder_input"]["aspect_ratio"].shape) # (bsz, max_num_images, 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this should be [2, 2, 2] or [2, 4]. @felipemello1 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aspect_ratio should be (bsz, max_num_images, 2), and then in the clip we reshape:

aspect_ratio = aspect_ratio.reshape(bsz_and_n_imgs, 2)

https://github.com/pytorch/torchtune/blob/82c232d0679ddef3fc419cdc18af758b98b4da05/torchtune/modules/vision_transformer.py#L354C9-L354C63

collated_text = padded_collate_sft(text_only, padding_idx, ignore_idx)
max_seq_len = collated_text["tokens"].shape[-1]

# TODO: Figure out how to make this more efficient or vectorized. Setting
Copy link
Contributor

@felipemello1 felipemello1 Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didnt think too much about it, but maybe:

  1. do a first pass to check the max of each dimension.
  2. create a tensor with all zeros. Pre allocating should simplify all the padding.
  3. Add the input to the tensor correct line: eg. tensor[0] += sample

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-allocating would definitely simplify the code. I would still need to loop through each individual image though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will leave this as a follow-up though in the interest of time

@RdoubleA
Copy link
Contributor Author

RdoubleA commented Sep 5, 2024

I think we should name this padded_collate_vision_text for now and pretend that it's general until it's not.

I actually think we should go the other way and keep this overspecified until we have a concrete use case to generalize it, I'd rather be overspecific than mislead users into thinking this can be used for any multimodal model

- "tokens": List[int] of length text_seq_len, varies across samples
- "labels": List[int] of length text_seq_len, varies across samples
- "encoder_input": Dict[str, List[torch.Tensor]]
- "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you say somewhere c, h, w = channel, height, width?

Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tested this and everything seems to be working accurately. The only issue right now is that this can't be used for inference. We need to expose padded_direction for text and not expect "labels" when padded_direction=="left". Also, I'd like to propose "padded_collate_tiled_images_and_mask".

Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@pbontrager pbontrager merged commit 377abc0 into pytorch:main Sep 11, 2024
17 checks passed
@RdoubleA RdoubleA deleted the mm_collator branch September 11, 2024 22:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants