
Fix CLS handling for llava next video
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Dec 16, 2024
1 parent 8202e79 commit 4858bbb
Showing 3 changed files with 164 additions and 73 deletions.
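
For orientation before the diff: the fix infers whether the vision tower emits a CLS token by comparing the selected feature length to the patch grid, and keeps all features when there is no CLS token to remove. A minimal, self-contained sketch of that idea (the helper below is illustrative only, not part of this commit; the 336px / patch-14 grid is just an example geometry):

```python
import torch


def select_vision_features(hidden_state: torch.Tensor, image_size: int, patch_size: int, strategy: str) -> torch.Tensor:
    """Drop the CLS token under the "default" strategy only if the tower actually produced one."""
    num_patches = (image_size // patch_size) ** 2
    # A CLS-bearing tower (e.g. CLIP ViT) yields num_patches + 1 tokens per crop; a tower
    # without CLS (e.g. SigLIP) yields exactly num_patches, so there is nothing to remove.
    has_cls = hidden_state.shape[1] != num_patches
    if strategy == "default" and has_cls:
        return hidden_state[:, 1:]
    return hidden_state


# CLIP-like tower: 336 // 14 = 24, 24 * 24 + 1 = 577 tokens -> 576 kept under "default"
clip_like = torch.randn(2, 577, 32)
assert select_vision_features(clip_like, 336, 14, "default").shape[1] == 576

# Tower without a CLS token: 576 tokens -> all 576 kept, even under "default"
no_cls = torch.randn(2, 576, 32)
assert select_vision_features(no_cls, 336, 14, "default").shape[1] == 576
```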
122 changes: 105 additions & 17 deletions src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -357,7 +357,9 @@ def unpad_image(tensor, original_size):
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
If `"full"`, the full vision features are used. In models where there appears to be no CLS token, e.g.,
SigLIP, the vision_feature_select_strategy will use all vision features since there is no CLS token to
remove.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
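
As a brief usage illustration (a hedged sketch, not taken from this commit), the selection strategy is normally set on the model config; after this fix it is effectively a no-op for CLS-free vision towers:

```python
from transformers import LlavaNextVideoConfig

# "full" keeps the CLS token when the vision tower emits one; "default" drops it.
config = LlavaNextVideoConfig(vision_feature_select_strategy="full")

# With a SigLIP-style vision tower there is no CLS token, so "default" and "full"
# now select the same (all) vision features.
```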
@@ -676,7 +678,47 @@ def _merge_input_ids_with_image_features(

return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids

def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
@staticmethod
def validate_image_feature_dims(
image_feature: torch.Tensor, height: int, width: int, has_cls: bool, vision_feature_select_strategy: str
):
"""
Validate the shape of the image features for the provided vision_feature_select_strategy.
Args:
image_feature (`torch.Tensor`)
Tensor of visual features with shape `(num_patches, image_length, embed_dim)`.
height (`int`)
The height of the image in patches.
width (`int`)
The width of the image in patches.
has_cls (`bool`)
Whether the visual encoder produces a CLS token; if it does not, the expected number of
features is `height * width` regardless of the selection strategy.
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
Raises:
ValueError: If the number of features in `image_feature` is inconsistent with the image size.
"""

base_image_feature = image_feature[0]
if vision_feature_select_strategy == "default" or not has_cls:
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")

def pack_image_features(
self,
image_features: List[torch.Tensor],
image_sizes: torch.Tensor,
vision_feature_select_strategy: str,
image_newline: Optional[torch.Tensor] = None,
has_cls: bool = True,
):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -689,6 +731,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
has_cls (`bool`)
Whether the visual encoder produces a CLS token; if it does not, the expected number of
features per patch crop is `height * width` regardless of the selection strategy.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
@@ -701,13 +746,8 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

if vision_feature_select_strategy == "default":
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
# Raise if the image feature dimensions are inconsistent with the image size
self.validate_image_feature_dims(image_feature, height, width, has_cls, vision_feature_select_strategy)

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
@@ -762,6 +802,38 @@ def get_image_features(
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches,
of shape `(num_patches, image_length, embed_dim)`.
"""
return self._get_image_features(
pixel_values, image_sizes, vision_feature_layer, vision_feature_select_strategy
)[0]

def _get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies the multimodal projection.
Additionally indicates whether or not the vision tower features appear to include a CLS token; if
no CLS token is present in the features, all features are returned.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
tuple: A tuple of length two containing:
- A list of image feature tensors, each containing all the visual features of all patches,
of shape `(num_patches, image_length, embed_dim)`.
- A bool indicating whether or not the image features include a CLS token.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
@@ -781,13 +853,21 @@ def get_image_features(

image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":

# Check to see if the output feature dimension has CLS or not;
# if there is no CLS, we should take all of the features.
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patches = height * width
has_cls = num_patches != selected_image_feature.shape[1]

if vision_feature_select_strategy == "full" or not has_cls:
selected_image_feature = selected_image_feature
elif vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]

image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
return image_features, has_cls

@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
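
Splitting the public `get_image_features` from the new private `_get_image_features` keeps the public return type unchanged while letting `forward` thread the detected `has_cls` flag into `pack_image_features`. A minimal stand-in sketch of that contract (hypothetical stubs with placeholder shapes, not the real methods):

```python
from typing import List, Tuple

import torch


def _get_image_features_stub() -> Tuple[List[torch.Tensor], bool]:
    # One image made of 5 crops, 576 tokens per crop, hidden size 64 (placeholder numbers).
    features = [torch.randn(5, 576, 64)]
    has_cls = False  # e.g. a SigLIP-style tower: token count equals the patch grid size
    return features, has_cls


def get_image_features_stub() -> List[torch.Tensor]:
    # The public method keeps its original return type and simply drops the has_cls flag.
    return _get_image_features_stub()[0]


features, has_cls = _get_image_features_stub()
assert get_image_features_stub()[0].shape == features[0].shape == (5, 576, 64)
```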
@@ -929,7 +1009,7 @@ def forward(

image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
image_features, has_cls = self._get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
@@ -940,6 +1020,7 @@ def forward(
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
has_cls=has_cls,
)

video_features = video_feature_lens = None
@@ -1147,10 +1228,17 @@ def get_video_features(
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_video_features = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":

# Check to see if the output feature dimension has CLS or not;
# if there is no CLS, we should take all of the features.
patches_height = patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patches = patches_height * patches_width
has_cls = num_patches != selected_video_features.shape[1]

if vision_feature_select_strategy == "full" or not has_cls:
selected_video_features = selected_video_features
elif vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]

# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
@@ -25,7 +25,6 @@
LlavaNextCausalLMOutputWithPast,
LlavaNextForConditionalGeneration,
LlavaNextPreTrainedModel,
image_size_to_num_patches,
)

from ...configuration_utils import PretrainedConfig
@@ -229,57 +228,6 @@ def __init__(self, config: LlavaNextVideoConfig, **super_kwargs):
self.vision_resampler = LlavaNextVideoPooler(config)
self.post_init()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
and are of shape `(num_patches, image_length, embed_dim)`).
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features

def get_video_features(
self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
):
@@ -302,10 +250,17 @@ def get_video_features(
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_video_features = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":

# Check to see if the output feature dimension has CLS or not;
# if there is no CLS, we should take all of the features.
patches_height = patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patches = patches_height * patches_width
has_cls = num_patches != selected_video_features.shape[1]

if vision_feature_select_strategy == "full" or not has_cls:
selected_video_features = selected_video_features
elif vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]

# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
@@ -451,7 +406,7 @@ def forward(

image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
image_features, has_cls = self._get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
@@ -462,6 +417,7 @@ def forward(
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
has_cls=has_cls,
)

video_features = video_feature_lens = None
47 changes: 47 additions & 0 deletions tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -15,17 +15,20 @@
"""Testing suite for the PyTorch Llava-NeXT-Video model."""

import unittest
from unittest.mock import patch

import numpy as np
import requests
from huggingface_hub import hf_hub_download
from parameterized import parameterized

from transformers import (
AutoProcessor,
LlavaNextVideoConfig,
LlavaNextVideoForConditionalGeneration,
is_torch_available,
is_vision_available,
modeling_outputs,
)
from transformers.testing_utils import (
cleanup,
@@ -341,6 +344,50 @@ def test_mismatching_num_image_tokens(self):
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

@parameterized.expand(
[
(4, True, "default"),
(5, True, "full"),
(4, False, "default"),
(4, False, "full"),
],
)
def test_visual_encoder_cls_behavior_images(self, num_expected_features, has_cls, strategy):
"""
Test that we correctly handle error checking for the dimensions of visual encoders
that do/don't have CLS tokens. If the visual encoder has no CLS token, i.e., produces
`patches_height * patches_width` features, the strategy is always treated as `"full"`,
because there is nothing to remove.
"""
# Mock the tower outputs; 5 means we have CLS, 4 means we don't,
# since area in patches for the test model is 2x2 -> 4 features
num_features = 5 if has_cls else 4
tower_output = modeling_outputs.BaseModelOutputWithPooling(
hidden_states=[torch.Tensor(15, num_features, 32).to(torch_device) for x in range(3)]
)
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
# Get the image features using the selected strategy and mocked tower outputs
with patch.object(model.vision_tower, "forward", return_value=tower_output):
image_features = model.get_image_features(
pixel_values=input_dict["pixel_values"],
image_sizes=input_dict["image_sizes"],
vision_feature_layer=-1,
vision_feature_select_strategy=strategy,
)
# Ensure that our dimensions match up based on the tower output and strategy
assert image_features[0].shape[1] == num_expected_features
# Run the validation that is used when packing images
height = width = config.vision_config.image_size // config.vision_config.patch_size
model.validate_image_feature_dims(
image_feature=image_features[0],
height=height,
width=width,
has_cls=has_cls,
vision_feature_select_strategy=strategy,
)

@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
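
The parameterized expectations above follow from the tiny test config's 2x2 patch grid; a standalone sketch of the same arithmetic (hedged, mirroring the validation rule rather than importing the test):

```python
def expected_tokens(strategy: str, has_cls: bool, height: int = 2, width: int = 2) -> int:
    # Mirrors validate_image_feature_dims: "full" keeps the CLS token when one exists,
    # "default" drops it, and a CLS-free tower always yields height * width tokens.
    if strategy == "full" and has_cls:
        return height * width + 1
    return height * width


# Matches the parameterized cases above: (4, True, "default"), (5, True, "full"),
# (4, False, "default"), (4, False, "full").
assert expected_tokens("default", True) == 4
assert expected_tokens("full", True) == 5
assert expected_tokens("default", False) == 4
assert expected_tokens("full", False) == 4
```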
