diff --git a/ludwig/api.py b/ludwig/api.py
index 942034cfc1c..6abe4068831 100644
--- a/ludwig/api.py
+++ b/ludwig/api.py
@@ -615,6 +615,14 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
             # auto tune batch size
             self._tune_batch_size(trainer, training_set, random_seed=random_seed)
 
+            if (
+                self.config_obj.model_type == "LLM"
+                and trainer.config.type == "none"
+                and self.config_obj.adapter is not None
+                and self.config_obj.adapter.pretrained_adapter_weights is not None
+            ):
+                trainer.model.initialize_adapter()  # Load pretrained adapter weights for inference only
+
             # train model
             if self.backend.is_coordinator():
                 print_boxed("TRAINING")
diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py
index 1621ee4809f..b873c1087a2 100644
--- a/ludwig/config_validation/checks.py
+++ b/ludwig/config_validation/checks.py
@@ -494,6 +494,14 @@ def check_llm_finetuning_trainer_config(config: "ModelConfig"):  # noqa: F821
     if config.model_type != MODEL_LLM:
         return
 
+    if (
+        config.trainer.type == "none"
+        and config.adapter is not None
+        and config.adapter.pretrained_adapter_weights is not None
+    ):
+        # Zero-shot/inference-only: pretrained adapter weights are loaded without a finetune trainer
+        return
+
     if config.adapter is not None and config.trainer.type != "finetune":
         raise ConfigValidationError("LLM finetuning requires trainer type to be finetune.")
 
@@ -509,7 +517,11 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"):  # noqa: F821
         return
 
     # LLM finetuning is only supported by the finetune trainer type
-    if config.trainer.type != "finetune":
+    if (
+        config.trainer.type != "finetune"
+        and config.adapter is not None
+        and config.adapter.pretrained_adapter_weights is not None
+    ):
         return
 
     # Using local backend, so skip the checks below
@@ -529,9 +541,8 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"):  # noqa: F821
 def check_llm_finetuning_adalora_config(config: "ModelConfig"):
     """Checks that the adalora adapter is configured correctly.
 
-    It requires a set of target_modules to be specified in the config for the model. If it isn't specified by the user,
-    we also check against PEFT's predefined target module list for ADALORA to see if this key is present there. If
-    neither is true, AdaloraModel will run into issues downstream.
+    We check PEFT's predefined target module mapping for ADALORA to see if the base model's type is present there. If
+    it is not, AdaloraModel will run into issues downstream.
     """
     if config.model_type != MODEL_LLM:
         return
@@ -545,10 +556,7 @@ def check_llm_finetuning_adalora_config(config: "ModelConfig"):
     from peft.utils import TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING
 
     model_config = _get_llm_model_config(config.base_model)
-    if (
-        not config.adapter.target_modules
-        and model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING
-    ):
+    if model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
         raise ConfigValidationError(
             f"Adalora adapter is not supported for {model_config.model_type} model. "
             f"Supported model types are: {list(TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING.keys())}. "
diff --git a/ludwig/constants.py b/ludwig/constants.py
index 226a60a4ed6..d2cc455df24 100644
--- a/ludwig/constants.py
+++ b/ludwig/constants.py
@@ -282,6 +282,7 @@
 GENERATION = "generation"
 PROMPT = "prompt"
 ADAPTER = "adapter"
+PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights"
 
 # CrossEntropyLoss for LLMs
 IGNORE_INDEX_TOKEN_ID = -100
diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py
index fc7319e3114..408d8c89cf4 100644
--- a/ludwig/models/llm.py
+++ b/ludwig/models/llm.py
@@ -216,18 +216,51 @@ def output_feature_decoder(self) -> OutputFeature:
     def initialize_adapter(self):
         """If an adapter config is provided, we want to wrap the model with a PEFT model for fine-tuning."""
         if self.config_obj.adapter:
-            if self.config_obj.trainer.type != "finetune":
+            if self.config_obj.trainer.type != "finetune" and not self.config_obj.adapter.pretrained_adapter_weights:
                 raise ValueError(
                     "Adapter config was provided, but trainer type is not set to `finetune`. Either set the trainer to "
                     "`finetune` or remove the adapter config."
                 )
 
-            from peft import get_peft_model, TaskType
+            from peft import get_peft_model
+
+            if self.config_obj.adapter.pretrained_adapter_weights:
+                # If pretrained adapter weights are provided, load them into the model
+                logger.info(f"Using pretrained adapter weights: {self.config_obj.adapter.pretrained_adapter_weights}")
+                from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig
+
+                peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_adapter_weights)
+                peft_dict = peft_config.to_dict()
+
+                # Update the PEFT config with values from config_obj, because not all of them are set by HF
+                for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items():
+                    # Only copy values that are set on the Ludwig adapter config and not already present in the
+                    # loaded PEFT config.
+                    if param_value is None:
+                        # param_name and param_value come from the config object and contain default
+                        # values for the adapter. Examples of parameters with missing values might be:
+                        # 'auto_mapping', 'base_model_name_or_path', and 'task_type'.
+                        # Note that some of these values might already be set in peft_config, which comes from HF
+                        # directly (specifically, adapter_config.json in the model repo), and we don't want to
+                        # override those values with None.
+                        continue
+                    if param_name not in peft_dict:
+                        # If any parameters are not set in adapter_config.json in HF, we want to populate them with the
+                        # appropriate default values.
+                        setattr(peft_config, param_name, param_value)
+
+                self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained(
+                    self.model, self.config_obj.adapter.pretrained_adapter_weights
+                )
+            else:
+                # If no pretrained adapter weights are provided, initialize a new, untrained adapter
+                from peft import TaskType
 
-            peft_config = self.config_obj.adapter.to_config(
-                task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
-            )
-            self.model = get_peft_model(self.model, peft_config)
+                peft_config = self.config_obj.adapter.to_config(
+                    task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
+                )
+
+                self.model = get_peft_model(self.model, peft_config)
 
             logger.info("==================================================")
             logger.info("Trainable Parameter Summary For Fine-Tuning")
diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py
index 6e127ee5fbf..3ce30aeb07c 100644
--- a/ludwig/schema/llms/peft.py
+++ b/ludwig/schema/llms/peft.py
@@ -30,6 +30,10 @@ def wrap(config: BaseAdapterConfig):
 class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
     type: str
 
+    pretrained_adapter_weights: Optional[str] = schema_utils.String(
+        default=None, description="Path to pretrained adapter weights.", allow_none=True
+    )
+
     @abstractmethod
     def to_config(self, **kwargs) -> "PeftConfig":
         pass
@@ -359,7 +363,7 @@ def description(cls) -> str:
 @register_adapter("adaption_prompt")
 @ludwig_dataclass
 class AdaptionPromptConfig(BaseAdapterConfig):
-    """Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt.py."""
+    """Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt/config.py."""
 
     def __post_init__(self):
         if not self.adapter_len:
diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index a10d7745b39..91c53c9a7e8 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -18,6 +18,7 @@
     MODEL_TYPE,
     OUTPUT_FEATURES,
     PREPROCESSING,
+    PRETRAINED_ADAPTER_WEIGHTS,
     PROMPT,
     TRAINER,
     TYPE,
@@ -492,12 +493,56 @@ def test_default_max_sequence_length():
             BATCH_SIZE: 8,
             EPOCHS: 2,
         },
+        ADAPTER: {TYPE: "lora", PRETRAINED_ADAPTER_WEIGHTS: "Infernaught/test_adapter_weights"},
+        BACKEND: {TYPE: "local"},
     }
     config_obj = ModelConfig.from_dict(config)
     assert config_obj.input_features[0].preprocessing.max_sequence_length is None
     assert config_obj.output_features[0].preprocessing.max_sequence_length is None
 
 
+@pytest.mark.parametrize("adapter", ["lora", "adalora", "adaption_prompt"])
+def test_load_pretrained_adapter_weights(adapter):
+    from peft import PeftModel
+    from transformers import PreTrainedModel
+
+    if adapter == "lora":
+        weights = "Infernaught/test_adapter_weights"
+        base_model = TEST_MODEL_NAME
+    elif adapter == "adalora":
+        weights = "Infernaught/test_adalora_weights"
+        base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
+    elif adapter == "adaption_prompt":
+        weights = "Infernaught/test_ap_weights"
+        base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
+    else:
+        raise ValueError(f"Unknown adapter type: {adapter}")
+
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: base_model,
+        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
+        OUTPUT_FEATURES: [text_feature(name="output")],
+        TRAINER: {
+            TYPE: "none",
+            BATCH_SIZE: 8,
+            EPOCHS: 2,
+        },
+        ADAPTER: {TYPE: adapter, PRETRAINED_ADAPTER_WEIGHTS: weights},
+        BACKEND: {TYPE: "local"},
+    }
+    config_obj = ModelConfig.from_dict(config)
+    model = LLM(config_obj)
+
+    assert model.config_obj.adapter.pretrained_adapter_weights
+    assert model.config_obj.adapter.pretrained_adapter_weights == weights
+
+    # Loading the pretrained adapter should wrap the base model in a PeftModel
+    model.prepare_for_training()
+    assert not isinstance(model.model, PreTrainedModel)
+    assert isinstance(model.model, PeftModel)
+
+
 def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
     # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
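
Usage sketch (not part of the diff): the snippet below is a minimal sketch that mirrors the adalora case of the integration test above, assuming the `Infernaught/test_adalora_weights` repo, the `HuggingFaceH4/tiny-random-LlamaForCausalLM` base model used there, and the `ModelConfig`/`LLM` import paths from the Ludwig codebase; the feature dicts are hand-written approximations of what the test's `text_feature()` helper generates. With `trainer.type: "none"` and `adapter.pretrained_adapter_weights` set, `initialize_adapter()` loads the adapter via `PeftConfig.from_pretrained` and `MODEL_TYPE_TO_PEFT_MODEL_MAPPING[...].from_pretrained` instead of wrapping the base model with `get_peft_model`, so the weights are available for inference without finetuning.

from peft import PeftModel

from ludwig.models.llm import LLM
from ludwig.schema.model_types.base import ModelConfig

config = {
    "model_type": "llm",
    "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
    "input_features": [{"name": "input", "type": "text", "encoder": {"type": "passthrough"}}],
    "output_features": [{"name": "output", "type": "text"}],
    "trainer": {"type": "none", "batch_size": 8, "epochs": 2},
    "adapter": {"type": "adalora", "pretrained_adapter_weights": "Infernaught/test_adalora_weights"},
    "backend": {"type": "local"},
}

config_obj = ModelConfig.from_dict(config)
model = LLM(config_obj)

# prepare_for_training() calls initialize_adapter(), which detects the configured
# pretrained adapter weights and wraps the base model in a PeftModel loaded from the hub.
model.prepare_for_training()
assert isinstance(model.model, PeftModel)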