From e7d86e4cefe6c82c8b3b30ea3f1163d948d9b04f Mon Sep 17 00:00:00 2001
From: Jeff Kinnison
Date: Thu, 11 Jan 2024 12:26:09 -0500
Subject: [PATCH] Add custom `prepare_for_training` logic to ECD model for LLM encoder adapter initialization (#3874)

---
 ludwig/encoders/text_encoders.py           | 14 +++++----
 ludwig/models/ecd.py                       | 14 ++++++++-
 tests/ludwig/encoders/test_llm_encoders.py | 34 ++++++++++++++++++----
 3 files changed, 50 insertions(+), 12 deletions(-)

diff --git a/ludwig/encoders/text_encoders.py b/ludwig/encoders/text_encoders.py
index bcf6dd70300..b445732a609 100644
--- a/ludwig/encoders/text_encoders.py
+++ b/ludwig/encoders/text_encoders.py
@@ -2395,6 +2395,8 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):
 
         self.config = encoder_config
 
+        self.adapter_is_initialized = False
+
         self.model_name = self.config.base_model
         self.model_config = AutoConfig.from_pretrained(self.config.base_model)
 
@@ -2421,8 +2423,6 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):
 
         self.attention_masks = None
 
-        self.prepare_for_training()
-
         clear_data_cache()
 
     @staticmethod
@@ -2451,6 +2451,8 @@ def initialize_adapter(self):
             self.model.print_trainable_parameters()
             logger.info("==================================================")
 
+        self.adapter_is_initialized = True
+
     def prepare_for_training(self):
         # TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this.
         if self.config.quantization:
@@ -2485,7 +2487,7 @@ def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
         # contents of the state_dict.
         # The three args to this method are supplied by Module.state_dict
         # https://github.com/pytorch/pytorch/blob/8739d1e3f9b08f4282fe79fc8dacd781d16913ff/torch/nn/modules/module.py#L1824
-        if self.config.adapter:
+        if self.config.adapter and self.adapter_is_initialized:
             # get_peft_model_state_dict generates a state dict that only contains the adapter weights
             from peft.utils.save_and_load import get_peft_model_state_dict
 
@@ -2498,7 +2500,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
 
         destination = super().state_dict(destination, prefix=prefix, keep_vars=keep_vars)
 
-        if self.config.adapter:
+        if self.config.adapter and self.adapter_is_initialized:
             adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
             exclude_model_keys = [k for k in destination.keys() if adapter_type_prefix not in k]
 
@@ -2518,7 +2520,7 @@ def _load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
 
-        if self.config.adapter:
+        if self.config.adapter and self.adapter_is_initialized:
             # When using an adapter, only the adapter weights are saved, and so we only want to load those weights.
             # Under the hood, PEFT alters the names of the parameters, which leads to an "unexpected keys" error when
             # using strict mode. This block uses PEFT's version of `load_state_dict` to handle loading in weights.
@@ -2540,7 +2542,7 @@ def remove_missing_non_adapter_keys(self, module, incompatible_keys):
         """
         # If no adapter was used, `LLMEncoder.load_state_dict` should use the default `torch.Module.load_state_dict`
        # code path to load weights and no modification should be necessary.
-        if self.config.adapter:
+        if self.config.adapter and self.adapter_is_initialized:
             adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
             missing_keys, unexpected_keys = incompatible_keys
 
diff --git a/ludwig/models/ecd.py b/ludwig/models/ecd.py
index d8d921598d2..a0d5b930c2b 100644
--- a/ludwig/models/ecd.py
+++ b/ludwig/models/ecd.py
@@ -6,7 +6,7 @@
 import torch
 
 from ludwig.combiners.combiners import create_combiner
-from ludwig.constants import MODEL_ECD
+from ludwig.constants import MODEL_ECD, MODEL_LLM
 from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
 from ludwig.models.base import BaseModel
 from ludwig.schema.model_types.ecd import ECDModelConfig
@@ -56,6 +56,18 @@ def __init__(
         # After constructing all layers, clear the cache to free up memory
         clear_data_cache()
 
+    def prepare_for_training(self):
+        # 1/10/23: For parity with how the LLM model type sets up adapters and quantization, LLM encoders should call
+        # `prepare_for_training` at training time rather than at initialization. This loop searches for input features
+        # using the LLM encoder and calls `prepare_for_training` on those encoders only. No other changes should be
+        # made to the ECD model itself or any other encoders.
+        for feature in self.config_obj.input_features:
+            encoder_type = feature.encoder.type
+            if encoder_type == MODEL_LLM:
+                feature_name = feature.name
+                encoder = self.input_features.get(feature_name)
+                encoder.prepare_for_training()
+
     def encode(
         self,
         inputs: Union[
diff --git a/tests/ludwig/encoders/test_llm_encoders.py b/tests/ludwig/encoders/test_llm_encoders.py
index 8fbcb211544..8b5f6b1faee 100644
--- a/tests/ludwig/encoders/test_llm_encoders.py
+++ b/tests/ludwig/encoders/test_llm_encoders.py
@@ -83,19 +83,37 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):
 
     @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
     def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
-        encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
-
-        # Test initializing with an adapter
         from peft import PeftModel
 
+        encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
         encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
 
+        # The adapter should not be initialized until `prepare_for_training` is called
+        assert not isinstance(encoder.model, PeftModel)
+        assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+
         assert encoder.model_name == encoder_config.base_model
-        assert isinstance(encoder.model, PeftModel)
-        assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))  # Check adapter was initialized
         assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
         assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])
 
+    @pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
+    def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: str):
+        from peft import PeftModel
+
+        encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
+        encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
+
+        # The adapter should not be initialized until `prepare_for_training` is called
+        assert not isinstance(encoder.model, PeftModel)
+        assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+
+        # Initialize the adapter
+        encoder.prepare_for_training()
+
+        # At this point, the adapter should be initialized and the state dict should contain adapter parameters
+        assert isinstance(encoder.model, PeftModel)
+        assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
+
     def test_save_to_state_dict(self, encoder_config: LLMEncoderConfig, tmpdir):
         # With no adapter, the state dict should only contain the model parameters
         encoder = LLMEncoder(encoder_config=encoder_config)
@@ -106,6 +124,8 @@ def test_save_to_state_dict_adapter(self, encoder_config: LLMEncoderConfig, adap
         # With an adapter, the state dict should only contain adapter parameters
         encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
         encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
+        # Initialize the adapters
+        encoder.prepare_for_training()
         assert all(map(lambda k: "lora_" in k, encoder.state_dict().keys()))
 
     @pytest.mark.parametrize("wrap", [False, True], ids=["no_wrapper", "with_wrapper"])
@@ -151,6 +171,10 @@ def weights_init(m):
         encoder1 = LLMEncoder(encoder_config=encoder_config_with_adapter)
         encoder2 = LLMEncoder(encoder_config=encoder_config_with_adapter)
 
+        # Initialize the adapters
+        encoder1.prepare_for_training()
+        encoder2.prepare_for_training()
+
        if wrap:
             encoder1 = WrapperModule(encoder1)
             encoder2 = WrapperModule(encoder2)
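
Usage sketch of the deferred-initialization flow exercised by the tests above, condensed into a standalone snippet. It assumes `encoder_config_with_adapter` is an `LLMEncoderConfig` with an adapter (e.g. LoRA) configured, as built by the `create_encoder_config_with_adapter` test helper; after this change the PEFT wrapper is only attached once `prepare_for_training()` is called, either directly or through `ECD.prepare_for_training()` at training time.

from peft import PeftModel

from ludwig.encoders.text_encoders import LLMEncoder

# `encoder_config_with_adapter` is assumed to be an LLMEncoderConfig with an adapter section,
# as in the test helper above.
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)

# Construction no longer wraps the base model in a PEFT adapter, so the state dict
# holds only base-model weights at this point.
assert not isinstance(encoder.model, PeftModel)
assert not any("lora_" in k for k in encoder.state_dict())

# Adapter (and quantization) setup is deferred to training time.
encoder.prepare_for_training()
assert isinstance(encoder.model, PeftModel)
assert any("lora_" in k for k in encoder.state_dict())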