Fix Llava Next Check for Visual Encoders Without CLS #35262

Closed
wants to merge 2 commits
109 changes: 95 additions & 14 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -320,7 +320,9 @@ def _init_weights(self, module):
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
If `"full"`, the full vision features are used. For models whose vision encoder emits no CLS token, e.g.
SigLIP, all vision features are used regardless of `vision_feature_select_strategy`, since there is no
CLS token to remove.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
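
To make the new behavior concrete, here is a minimal sketch (illustrative only, not part of this diff; the SigLIP-style shapes below are assumed typical values) of how the strategy interacts with an encoder that emits no CLS token:

import torch

# Illustrative shapes for a SigLIP-style encoder (no CLS token): 384 // 14 = 27,
# so the sequence length is 27 * 27 = 729 with an assumed embed dim of 1152.
hidden_state = torch.randn(1, 729, 1152)  # (batch, seq_len, embed_dim)
num_patches = 27 * 27

# The heuristic this PR adds: if the sequence length already equals the patch
# count, there is no CLS token to remove.
has_cls = num_patches != hidden_state.shape[1]  # False here

vision_feature_select_strategy = "default"
if vision_feature_select_strategy == "full" or not has_cls:
    selected = hidden_state          # keep all features
elif vision_feature_select_strategy == "default":
    selected = hidden_state[:, 1:]   # strip the leading CLS token

print(selected.shape)  # torch.Size([1, 729, 1152]); nothing was stripped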
@@ -643,7 +645,47 @@ def _merge_input_ids_with_image_features(

return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids

def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
@staticmethod
def validate_image_feature_dims(
    base_image_feature: torch.Tensor, height: int, width: int, has_cls: bool, vision_feature_select_strategy: str
):
    """
    Validate the length of the base image feature for the provided vision_feature_select_strategy.

    Args:
        base_image_feature (`torch.Tensor`)
            Tensor of visual features for the base image patch, of shape `(image_length, embed_dim)`.
        height (`int`)
            The height of the image in patches.
        width (`int`)
            The width of the image in patches.
        has_cls (`bool`)
            Whether the visual encoder prepends a CLS token to the image features; if `False`, the
            expected feature length is not incremented by one, even for the `"full"` strategy.
        vision_feature_select_strategy (`str`)
            The feature selection strategy used to select the vision feature from the vision backbone.
    Raises:
        `ValueError` if the number of features is inconsistent with the image size.
    """

    if vision_feature_select_strategy == "default" or not has_cls:
        expected_num_patches = height * width
    elif vision_feature_select_strategy == "full":
        expected_num_patches = height * width + 1
    else:
        raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
    if expected_num_patches != base_image_feature.shape[0]:
        raise ValueError("The number of patches is not consistent with the image size.")

def pack_image_features(
self,
image_features: List[torch.Tensor],
image_sizes: torch.Tensor,
vision_feature_select_strategy: str,
image_newline: Optional[torch.Tensor] = None,
has_cls: bool = True,
):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

@@ -656,6 +698,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
has_cls (`bool`)
    Whether the visual encoder prepends a CLS token to the image features; if `False`, the
    expected feature length is not incremented by one, even for the `"full"` strategy.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
@@ -668,13 +713,8 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

if vision_feature_select_strategy == "default":
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
# Throw if the base image feature length is inconsistent with the image size
self.validate_image_feature_dims(base_image_feature, height, width, has_cls, vision_feature_select_strategy)

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
@@ -729,6 +769,38 @@ def get_image_features(
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches,
of shape `(num_patches, image_length, embed_dim)`.
"""
return self._get_image_features(
pixel_values, image_sizes, vision_feature_layer, vision_feature_select_strategy
)[0]

def _get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies the multimodal projection.
Additionally indicates whether the vision tower features appear to include a CLS token; if no
CLS token is present, all features are returned.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
tuple: tuple of length two containing:
    - List of image feature tensors, each containing all the visual features of all patches,
      of shape `(num_patches, image_length, embed_dim)`.
    - bool indicating whether the image features include a CLS token.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
@@ -748,13 +820,21 @@ def get_image_features(

image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":

# Check whether the sequence length includes a CLS token;
# if there is no CLS token, take all of the features.
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patches = height * width
has_cls = num_patches != selected_image_feature.shape[1]

if vision_feature_select_strategy == "full" or not has_cls:
selected_image_feature = selected_image_feature
elif vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]

image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
return image_features, has_cls

@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -848,7 +928,7 @@ def forward(

image_features = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
image_features, has_cls = self._get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
@@ -861,6 +941,7 @@ def forward(
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
has_cls=has_cls,
)

if legacy_processing:
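Taken together, the hunks above wire the new signal through `forward`. Condensed, the flow this diff establishes looks roughly like the following sketch (placeholder variables; `self` stands for a `LlavaNextForConditionalGeneration` with these changes applied):

# _get_image_features now also reports whether the vision tower emitted a CLS token.
image_features, has_cls = self._get_image_features(
    pixel_values,
    image_sizes,
    vision_feature_layer=vision_feature_layer,
    vision_feature_select_strategy=vision_feature_select_strategy,
)
# has_cls is threaded into pack_image_features so its patch-count validation
# expects height * width features (no +1 for CLS) when the encoder emitted none.
image_features, feature_lens = self.pack_image_features(
    image_features,
    image_sizes,
    vision_feature_select_strategy=vision_feature_select_strategy,
    image_newline=self.image_newline,
    has_cls=has_cls,
)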
122 changes: 105 additions & 17 deletions src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -357,7 +357,9 @@ def unpad_image(tensor, original_size):
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
If `"full"`, the full vision features are used.
If `"full"`, the full vision features are used. For models whose vision encoder emits no CLS token, e.g.
SigLIP, all vision features are used regardless of `vision_feature_select_strategy`, since there is no
CLS token to remove.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -676,7 +678,47 @@ def _merge_input_ids_with_image_features(

return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids

def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
@staticmethod
def validate_image_feature_dims(
    base_image_feature: torch.Tensor, height: int, width: int, has_cls: bool, vision_feature_select_strategy: str
):
    """
    Validate the length of the base image feature for the provided vision_feature_select_strategy.

    Args:
        base_image_feature (`torch.Tensor`)
            Tensor of visual features for the base image patch, of shape `(image_length, embed_dim)`.
        height (`int`)
            The height of the image in patches.
        width (`int`)
            The width of the image in patches.
        has_cls (`bool`)
            Whether the visual encoder prepends a CLS token to the image features; if `False`, the
            expected feature length is not incremented by one, even for the `"full"` strategy.
        vision_feature_select_strategy (`str`)
            The feature selection strategy used to select the vision feature from the vision backbone.
    Raises:
        `ValueError` if the number of features is inconsistent with the image size.
    """

    if vision_feature_select_strategy == "default" or not has_cls:
        expected_num_patches = height * width
    elif vision_feature_select_strategy == "full":
        expected_num_patches = height * width + 1
    else:
        raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
    if expected_num_patches != base_image_feature.shape[0]:
        raise ValueError("The number of patches is not consistent with the image size.")

def pack_image_features(
self,
image_features: List[torch.Tensor],
image_sizes: torch.Tensor,
vision_feature_select_strategy: str,
image_newline: Optional[torch.Tensor] = None,
has_cls: bool = True,
):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.

@@ -689,6 +731,9 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
has_cls (`bool`)
    Whether the visual encoder prepends a CLS token to the image features; if `False`, the
    expected feature length is not incremented by one, even for the `"full"` strategy.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
@@ -701,13 +746,8 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

if vision_feature_select_strategy == "default":
expected_num_patches = height * width
elif vision_feature_select_strategy == "full":
expected_num_patches = height * width + 1
if expected_num_patches != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
# Throw if the base image feature length is inconsistent with the image size
self.validate_image_feature_dims(base_image_feature, height, width, has_cls, vision_feature_select_strategy)

num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
@@ -762,6 +802,38 @@ def get_image_features(
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches,
of shape `(num_patches, image_length, embed_dim)`.
"""
return self._get_image_features(
pixel_values, image_sizes, vision_feature_layer, vision_feature_select_strategy
)[0]

def _get_image_features(
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: int,
vision_feature_select_strategy: str,
):
"""
Obtains image last hidden states from the vision tower and applies the multimodal projection.
Additionally indicates whether the vision tower features appear to include a CLS token; if no
CLS token is present, all features are returned.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each image (H, W).
vision_feature_layer (`int`):
The index of the layer to select the vision feature.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
tuple: tuple of length two containing:
    - List of image feature tensors, each containing all the visual features of all patches,
      of shape `(num_patches, image_length, embed_dim)`.
    - bool indicating whether the image features include a CLS token.
"""
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
@@ -781,13 +853,21 @@ def get_image_features(

image_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":

# Check whether the sequence length includes a CLS token;
# if there is no CLS token, take all of the features.
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patches = height * width
has_cls = num_patches != selected_image_feature.shape[1]

if vision_feature_select_strategy == "full" or not has_cls:
selected_image_feature = selected_image_feature
elif vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]

image_features = self.multi_modal_projector(selected_image_feature)
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
return image_features, has_cls

@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -929,7 +1009,7 @@ def forward(

image_features = feature_lens = None
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
image_features, has_cls = self._get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
@@ -940,6 +1020,7 @@ def forward(
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
has_cls=has_cls,
)

video_features = video_feature_lens = None
@@ -1147,10 +1228,17 @@ def get_video_features(
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
selected_video_features = video_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":

# Check whether the sequence length includes a CLS token;
# if there is no CLS token, take all of the features.
patches_height = patches_width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patches = patches_height * patches_width
has_cls = num_patches != selected_video_features.shape[1]

if vision_feature_select_strategy == "full" or not has_cls:
selected_video_features = selected_video_features
elif vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]

# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
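As a sanity check on the `has_cls` heuristic used in both files, here is the arithmetic for two common vision towers (a sketch; the config values below are the usual published ones, assumed here rather than taken from this diff):

# Worked example of the has_cls heuristic for two common vision towers.
configs = {
    # CLIP ViT-L/14-336 (the stock LLaVA-Next tower): emits CLS + patch features.
    "clip-vit-large-patch14-336": {"image_size": 336, "patch_size": 14, "seq_len": 577},
    # SigLIP so400m/14-384: emits patch features only, no CLS token.
    "siglip-so400m-patch14-384": {"image_size": 384, "patch_size": 14, "seq_len": 729},
}

for name, cfg in configs.items():
    side = cfg["image_size"] // cfg["patch_size"]
    num_patches = side * side
    has_cls = num_patches != cfg["seq_len"]
    print(f"{name}: {num_patches} patches vs seq_len {cfg['seq_len']} -> has_cls={has_cls}")

# clip-vit-large-patch14-336: 576 patches vs seq_len 577 -> has_cls=True
# siglip-so400m-patch14-384: 729 patches vs seq_len 729 -> has_cls=False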