Add custom prepare_for_training logic to ECD model for LLM encoder adapter initialization #3874

Merged: 4 commits, Jan 11, 2024
ludwig/encoders/text_encoders.py: 14 changes (8 additions, 6 deletions)
@@ -2395,6 +2395,8 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):

self.config = encoder_config

self.adapter_is_initialized = False

self.model_name = self.config.base_model
self.model_config = AutoConfig.from_pretrained(self.config.base_model)

@@ -2421,8 +2423,6 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):

self.attention_masks = None

self.prepare_for_training()

clear_data_cache()

@staticmethod
@@ -2451,6 +2451,8 @@ def initialize_adapter(self):
self.model.print_trainable_parameters()
logger.info("==================================================")

self.adapter_is_initialized = True

def prepare_for_training(self):
# TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this.
if self.config.quantization:
@@ -2485,7 +2487,7 @@ def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
# contents of the state_dict.
# The three args to this method are supplied by Module.state_dict
# https://github.com/pytorch/pytorch/blob/8739d1e3f9b08f4282fe79fc8dacd781d16913ff/torch/nn/modules/module.py#L1824
if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
# get_peft_model_state_dict generates a state dict that only contains the adapter weights
from peft.utils.save_and_load import get_peft_model_state_dict

@@ -2498,7 +2500,7 @@ def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
destination = super().state_dict(destination, prefix=prefix, keep_vars=keep_vars)

if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
exclude_model_keys = [k for k in destination.keys() if adapter_type_prefix not in k]

@@ -2518,7 +2520,7 @@ def _load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
# When using an adapter, only the adapter weights are saved, and so we only want to load those weights.
# Under the hood, PEFT alters the names of the parameters, which leads to an "unexpected keys" error when
# using strict mode. This block uses PEFT's version of `load_state_dict` to handle loading in weights.
@@ -2540,7 +2542,7 @@ def remove_missing_non_adapter_keys(self, module, incompatible_keys):
"""
# If no adapter was used, `LLMEncoder.load_state_dict` should use the default `torch.Module.load_state_dict`
# code path to load weights and no modification should be necessary.
if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
missing_keys, unexpected_keys = incompatible_keys

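To make the intent of these changes easier to follow outside the diff, below is a minimal, self-contained sketch of the pattern introduced above: the adapter is no longer built in the constructor; instead a flag records whether prepare_for_training has run, and the state-dict logic only switches to adapter-only behavior once it has. The class and helper names (ToyEncoder, ToyAdapterConfig, the nn.Sequential "adapter") are illustrative stand-ins, not Ludwig's or PEFT's actual API.

from dataclasses import dataclass
from typing import Optional

import torch.nn as nn


@dataclass
class ToyAdapterConfig:
    type: str = "lora"


@dataclass
class ToyEncoderConfig:
    adapter: Optional[ToyAdapterConfig] = None


class ToyEncoder(nn.Module):
    """Illustrative stand-in for LLMEncoder's deferred adapter initialization."""

    def __init__(self, config: ToyEncoderConfig):
        super().__init__()
        self.config = config
        self.model = nn.Linear(8, 8)
        # The adapter is deliberately not created in the constructor anymore;
        # only the tracking flag is set here.
        self.adapter_is_initialized = False

    def initialize_adapter(self):
        # Stand-in for wrapping self.model with peft.get_peft_model(...).
        self.model = nn.Sequential(self.model, nn.Linear(8, 8))
        self.adapter_is_initialized = True

    def prepare_for_training(self):
        # Called explicitly at training time (see the ECD change below), not from __init__.
        if self.config.adapter and not self.adapter_is_initialized:
            self.initialize_adapter()

    def state_dict(self, *args, **kwargs):
        full = super().state_dict(*args, **kwargs)
        # Filter down to "adapter" weights only once the adapter exists; before
        # prepare_for_training() the plain base-model state dict is returned.
        if self.config.adapter and self.adapter_is_initialized:
            return {k: v for k, v in full.items() if k.startswith("model.1.")}
        return full

Constructed with an adapter config, this encoder keeps returning the base-model weights from state_dict() until prepare_for_training() is called, mirroring the adapter_is_initialized guards added to _save_to_state_dict, state_dict, _load_from_state_dict, and remove_missing_non_adapter_keys above.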
ludwig/models/ecd.py: 14 changes (13 additions, 1 deletion)
@@ -6,7 +6,7 @@
import torch

from ludwig.combiners.combiners import create_combiner
from ludwig.constants import MODEL_ECD
from ludwig.constants import MODEL_ECD, MODEL_LLM
from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
from ludwig.models.base import BaseModel
from ludwig.schema.model_types.ecd import ECDModelConfig
@@ -56,6 +56,18 @@ def __init__(
# After constructing all layers, clear the cache to free up memory
clear_data_cache()

def prepare_for_training(self):
# 1/10/24: For parity with how the LLM model type sets up adapters and quantization, LLM encoders should call
# `prepare_for_training` at training time rather than at initialization. This loop searches for input features
# using the LLM encoder and calls `prepare_for_training` on those encoders only. No other changes should be
# made to the ECD model itself or any other encoders.
for feature in self.config_obj.input_features:
encoder_type = feature.encoder.type
if encoder_type == MODEL_LLM:
feature_name = feature.name
encoder = self.input_features.get(feature_name)
encoder.prepare_for_training()

def encode(
self,
inputs: Union[
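The hook above is only useful if something calls it before training starts. The trainer-side call site is not part of this diff, so the following is only a hedged sketch of the intended flow, with fit and train_one_epoch as hypothetical placeholder names:

def fit(ecd_model, train_one_epoch, num_epochs):
    # New hook: only input features whose encoder type is MODEL_LLM do any
    # work here (adapter/quantization prep); all other encoders are untouched.
    ecd_model.prepare_for_training()

    for _ in range(num_epochs):
        train_one_epoch(ecd_model)

Because ECD.prepare_for_training only touches input features whose encoder type is MODEL_LLM, calling the hook on a model with no LLM encoders is effectively a no-op.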
tests/ludwig/encoders/test_llm_encoders.py: 34 changes (29 additions, 5 deletions)
@@ -83,19 +83,37 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):

@pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)

# Test initializing with an adapter
from peft import PeftModel

encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)

# The adapter should not be initialized until `prepare_for_training` is called
assert not isinstance(encoder.model, PeftModel)
assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

assert encoder.model_name == encoder_config.base_model
assert isinstance(encoder.model, PeftModel)
assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys())) # Check adapter was initialized
assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

@pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: str):
from peft import PeftModel

encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)

# The adapter should not be initialized until `prepare_for_training` is called
assert not isinstance(encoder.model, PeftModel)
assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

# Initialize the adapter
encoder.prepare_for_training()

# At this point, the adapter should be initialized and the state dict should contain adapter parameters
assert isinstance(encoder.model, PeftModel)
assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

def test_save_to_state_dict(self, encoder_config: LLMEncoderConfig, tmpdir):
# With no adapter, the state dict should only contain the model parameters
encoder = LLMEncoder(encoder_config=encoder_config)
@@ -106,6 +124,8 @@ def test_save_to_state_dict_adapter(self, encoder_config: LLMEncoderConfig, adap
# With an adapter, the state dict should only contain adapter parameters
encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
# Initialize the adapters
encoder.prepare_for_training()
assert all(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

@pytest.mark.parametrize("wrap", [False, True], ids=["no_wrapper", "with_wrapper"])
@@ -151,6 +171,10 @@ def weights_init(m):
encoder1 = LLMEncoder(encoder_config=encoder_config_with_adapter)
encoder2 = LLMEncoder(encoder_config=encoder_config_with_adapter)

# Initialize the adapters
encoder1.prepare_for_training()
encoder2.prepare_for_training()

if wrap:
encoder1 = WrapperModule(encoder1)
encoder2 = WrapperModule(encoder2)
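For completeness, here is a condensed version of the save/load round-trip that the updated state-dict test above exercises, with config construction elided; roundtrip is a hypothetical helper name, not part of the test suite:

import torch

from ludwig.encoders.text_encoders import LLMEncoder


def roundtrip(encoder_config_with_adapter):
    encoder1 = LLMEncoder(encoder_config=encoder_config_with_adapter)
    encoder2 = LLMEncoder(encoder_config=encoder_config_with_adapter)

    # Both encoders must initialize their adapters before their state dicts
    # become adapter-only; otherwise the base-model weights are saved/loaded.
    encoder1.prepare_for_training()
    encoder2.prepare_for_training()

    # The adapter-only state dict from one encoder loads into the other.
    encoder2.load_state_dict(encoder1.state_dict())

    sd1, sd2 = encoder1.state_dict(), encoder2.state_dict()
    for key in sd1:
        assert torch.equal(sd1[key], sd2[key])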