Commit
Add custom prepare_for_training logic to ECD model for LLM encoder adapter initialization (#3874)
jeffkinnison authored Jan 11, 2024
1 parent 22024d7 commit e7d86e4
Showing 3 changed files with 50 additions and 12 deletions.
14 changes: 8 additions & 6 deletions ludwig/encoders/text_encoders.py
@@ -2395,6 +2395,8 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):

self.config = encoder_config

self.adapter_is_initialized = False

self.model_name = self.config.base_model
self.model_config = AutoConfig.from_pretrained(self.config.base_model)

@@ -2421,8 +2423,6 @@ def __init__(self, encoder_config: LLMEncoderConfig = None, **kwargs):

self.attention_masks = None

self.prepare_for_training()

clear_data_cache()

@staticmethod
@@ -2451,6 +2451,8 @@ def initialize_adapter(self):
self.model.print_trainable_parameters()
logger.info("==================================================")

self.adapter_is_initialized = True

def prepare_for_training(self):
# TODO: this implementation will not work if resuming from a previous checkpoint. Need to fix this.
if self.config.quantization:
@@ -2485,7 +2487,7 @@ def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
# contents of the state_dict.
# The three args to this method are supplied by Module.state_dict
# https://github.com/pytorch/pytorch/blob/8739d1e3f9b08f4282fe79fc8dacd781d16913ff/torch/nn/modules/module.py#L1824
if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
# get_peft_model_state_dict generates a state dict that only contains the adapter weights
from peft.utils.save_and_load import get_peft_model_state_dict

@@ -2498,7 +2500,7 @@ def _save_to_state_dict(self, destination: Dict, prefix: str, keep_vars: bool):
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
destination = super().state_dict(destination, prefix=prefix, keep_vars=keep_vars)

if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
exclude_model_keys = [k for k in destination.keys() if adapter_type_prefix not in k]

@@ -2518,7 +2520,7 @@ def _load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)

if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
# When using an adapter, only the adapter weights are saved, and so we only want to load those weights.
# Under the hood, PEFT alters the names of the parameters, which leads to an "unexpected keys" error when
# using strict mode. This block uses PEFT's version of `load_state_dict` to handle loading in weights.
@@ -2540,7 +2542,7 @@ def remove_missing_non_adapter_keys(self, module, incompatible_keys):
"""
# If no adapter was used, `LLMEncoder.load_state_dict` should use the default `torch.Module.load_state_dict`
# code path to load weights and no modification should be necessary.
if self.config.adapter:
if self.config.adapter and self.adapter_is_initialized:
adapter_type_prefix = self.ADAPTER_PARAM_NAME_PREFIX[self.config.adapter.type]
missing_keys, unexpected_keys = incompatible_keys

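
Taken together, these hunks make LLMEncoder attach its PEFT adapter lazily: the constructor only records `adapter_is_initialized = False`, `prepare_for_training` performs the wrapping and flips the flag, and the state-dict hooks skip their adapter-specific handling until the flag is set. The snippet below is a minimal, self-contained sketch of that pattern, not the Ludwig implementation; `ToyLLMEncoder` and the simple `"lora_"` key filter are illustrative stand-ins (the real encoder uses `peft.utils.save_and_load.get_peft_model_state_dict`).

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

class ToyLLMEncoder(torch.nn.Module):
    # Illustrative stand-in for LLMEncoder's deferred-adapter behavior (not Ludwig code).
    def __init__(self, base_model: str, lora_config: LoraConfig):
        super().__init__()
        self.adapter_is_initialized = False  # nothing is wrapped at construction time
        self.lora_config = lora_config
        self.model = AutoModelForCausalLM.from_pretrained(base_model)

    def prepare_for_training(self):
        # Attach the adapter only when training is about to start.
        if not self.adapter_is_initialized:
            self.model = get_peft_model(self.model, self.lora_config)
            self.adapter_is_initialized = True

    def state_dict(self, *args, **kwargs):
        sd = super().state_dict(*args, **kwargs)
        # Filter down to adapter weights only once the adapter actually exists,
        # mirroring the `self.config.adapter and self.adapter_is_initialized` guard.
        if self.adapter_is_initialized:
            sd = {k: v for k, v in sd.items() if "lora_" in k}
        return sd

The point of the guard is that a state dict requested before `prepare_for_training` has run still contains the full base-model weights, rather than applying PEFT-specific filtering to a model that has no adapter yet.
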
14 changes: 13 additions & 1 deletion ludwig/models/ecd.py
@@ -6,7 +6,7 @@
import torch

from ludwig.combiners.combiners import create_combiner
from ludwig.constants import MODEL_ECD
from ludwig.constants import MODEL_ECD, MODEL_LLM
from ludwig.globals import MODEL_WEIGHTS_FILE_NAME
from ludwig.models.base import BaseModel
from ludwig.schema.model_types.ecd import ECDModelConfig
@@ -56,6 +56,18 @@ def __init__(
# After constructing all layers, clear the cache to free up memory
clear_data_cache()

def prepare_for_training(self):
# 1/10/23: For parity with how the LLM model type sets up adapters and quantization, LLM encoders should call
# `prepare_for_training` at training time rather than at initialization. This loop searches for input features
# using the LLM encoder and calls `prepare_for_training` on those encoders only. No other changes should be
# made to the ECD model itself or any other encoders.
for feature in self.config_obj.input_features:
encoder_type = feature.encoder.type
if encoder_type == MODEL_LLM:
feature_name = feature.name
encoder = self.input_features.get(feature_name)
encoder.prepare_for_training()

def encode(
self,
inputs: Union[
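
On the ECD side, the new `prepare_for_training` only touches input features whose encoder type is the LLM encoder. The loop can be illustrated with a self-contained toy example; `features`, `encoders`, and `ToyEncoder` are stand-ins for `config_obj.input_features`, `self.input_features`, and the LLM encoder, not Ludwig objects, and the `"llm"` value of MODEL_LLM is assumed for this sketch.

from types import SimpleNamespace

MODEL_LLM = "llm"  # assumed value of ludwig.constants.MODEL_LLM for this sketch

class ToyEncoder:
    def __init__(self):
        self.prepared = False

    def prepare_for_training(self):
        self.prepared = True

features = [
    SimpleNamespace(name="review_text", encoder=SimpleNamespace(type=MODEL_LLM)),
    SimpleNamespace(name="price", encoder=SimpleNamespace(type="passthrough")),
]
encoders = {"review_text": ToyEncoder(), "price": ToyEncoder()}

# Only features using the LLM encoder are prepared; every other encoder is untouched.
for feature in features:
    if feature.encoder.type == MODEL_LLM:
        encoders[feature.name].prepare_for_training()

assert encoders["review_text"].prepared and not encoders["price"].prepared
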
34 changes: 29 additions & 5 deletions tests/ludwig/encoders/test_llm_encoders.py
@@ -83,19 +83,37 @@ def test_init(self, encoder_config: LLMEncoderConfig, model_config):

@pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
def test_init_with_adapter(self, encoder_config: LLMEncoderConfig, adapter: str, model_config):
encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)

# Test initializing with an adapter
from peft import PeftModel

encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)

# The adapter should not be initialized until `prepare_for_training` is called
assert not isinstance(encoder.model, PeftModel)
assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

assert encoder.model_name == encoder_config.base_model
assert isinstance(encoder.model, PeftModel)
assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys())) # Check adapter was initialized
assert encoder.input_shape == torch.Size([encoder_config.max_sequence_length])
assert encoder.output_shape == torch.Size([encoder_config.max_sequence_length, model_config.hidden_size])

@pytest.mark.parametrize("adapter", list(ADAPTER_CONFIG_MAP.keys()))
def test_prepare_for_training(self, encoder_config: LLMEncoderConfig, adapter: str):
from peft import PeftModel

encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)

# The adapter should not be initialized until `prepare_for_training` is called
assert not isinstance(encoder.model, PeftModel)
assert not any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

# Initialize the adapter
encoder.prepare_for_training()

# At this point, the adapter should be initialized and the state dict should contain adapter parameters
assert isinstance(encoder.model, PeftModel)
assert any(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

def test_save_to_state_dict(self, encoder_config: LLMEncoderConfig, tmpdir):
# With no adapter, the state dict should only contain the model parameters
encoder = LLMEncoder(encoder_config=encoder_config)
@@ -106,6 +124,8 @@ def test_save_to_state_dict_adapter(self, encoder_config: LLMEncoderConfig, adap
# With an adapter, the state dict should only contain adapter parameters
encoder_config_with_adapter = self.create_encoder_config_with_adapter(encoder_config, adapter)
encoder = LLMEncoder(encoder_config=encoder_config_with_adapter)
# Initialize the adapters
encoder.prepare_for_training()
assert all(map(lambda k: "lora_" in k, encoder.state_dict().keys()))

@pytest.mark.parametrize("wrap", [False, True], ids=["no_wrapper", "with_wrapper"])
@@ -151,6 +171,10 @@ def weights_init(m):
encoder1 = LLMEncoder(encoder_config=encoder_config_with_adapter)
encoder2 = LLMEncoder(encoder_config=encoder_config_with_adapter)

# Initialize the adapters
encoder1.prepare_for_training()
encoder2.prepare_for_training()

if wrap:
encoder1 = WrapperModule(encoder1)
encoder2 = WrapperModule(encoder2)
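
Outside the test suite, the contract these tests pin down can be exercised in a few lines. The snippet below is a sketch only: `config_with_lora` is a hypothetical placeholder for an LLMEncoderConfig that specifies a LoRA adapter, and the import path follows the file touched in this commit.

from peft import PeftModel
from ludwig.encoders.text_encoders import LLMEncoder

encoder = LLMEncoder(encoder_config=config_with_lora)  # config_with_lora: hypothetical adapter config

assert not isinstance(encoder.model, PeftModel)             # still the bare base model
assert not any("lora_" in k for k in encoder.state_dict())  # no adapter weights yet

encoder.prepare_for_training()                              # attach the adapter

assert isinstance(encoder.model, PeftModel)                 # now PEFT-wrapped
assert any("lora_" in k for k in encoder.state_dict())      # adapter weights present
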
