Allow user to specify huggingface link or local path to pretrained lora weights #3572
Changes from 1 commit
@@ -221,12 +221,36 @@ def initialize_adapter(self):
                     "`finetune` or remove the adapter config."
                 )

-            from peft import get_peft_model, TaskType
-
-            peft_config = self.config_obj.adapter.to_config(
-                task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
-            )
-            self.model = get_peft_model(self.model, peft_config)
+            from peft import get_peft_model
+
+            pretrained = False
+            if self.config_obj.adapter.pretrained_weights:
+                print(f"PRETRAINED_WEIGHTS: {self.config_obj.adapter.pretrained_weights}")
+                # If pretrained adapter weights are provided, we want to load them into the model
+                from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig
+
+                pretrained = True
+                peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_weights)
+                peft_dict = peft_config.to_dict()
+                for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items():
+                    if param_name is None:
+                        continue
+
+                    if param_name not in peft_dict:
+                        setattr(peft_config, param_name, param_value)
+
+                self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained(
+                    self.model, self.config_obj.adapter.pretrained_weights
+                )
+            else:
+                # If no pretrained adapter is provided, we want to load untrained weights into the model
+                from peft import TaskType
+
+                peft_config = self.config_obj.adapter.to_config(
+                    task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
+                )
+
+            self.model = get_peft_model(self.model, peft_config, pretrained=pretrained)

             logger.info("==================================================")
             logger.info("Trainable Parameter Summary For Fine-Tuning")

Review thread on the `if param_name is None:` check:
- When would param_name be None?
- You're correct. This should be param_value.
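A minimal standalone sketch of the parameter-merge step with the fix this thread converges on (skip unset values rather than unset names). This is not code from the PR: the helper name merge_adapter_params and the adapter_overrides argument are invented for illustration, and peft_config is assumed to be whatever PeftConfig.from_pretrained returned.

from peft import PeftConfig

def merge_adapter_params(peft_config: PeftConfig, adapter_overrides: dict) -> PeftConfig:
    """Copy schema-level adapter settings onto a loaded PeftConfig without clobbering it."""
    peft_dict = peft_config.to_dict()
    for param_name, param_value in adapter_overrides.items():
        # Skip values that were never set; per the review, the None check belongs on the value.
        if param_value is None:
            continue
        # Only add settings the loaded config does not already define.
        if param_name not in peft_dict:
            setattr(peft_config, param_name, param_value)
    return peft_config

In initialize_adapter this would be called with self.config_obj.adapter.to_config().to_dict() as the overrides, mirroring the loop in the hunk above.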
@@ -69,6 +69,18 @@ class LoraConfig(BaseAdapterConfig):
         description="Bias type for Lora.",
     )

+    pretrained_weights: Optional[str] = schema_utils.String(
+        default="none",
+        description="Path to pretrained weights for Lora.",
+    )
+
+    target_modules: Optional[list] = schema_utils.List(
+        str,
+        default=None,
+        allow_none=True,
+        description="List of modules to apply Lora to. If None, apply to all modules.",
+    )
+
     def to_config(self, task_type: str = None, **kwargs) -> "PeftConfig":
         from peft import LoraConfig as _LoraConfig

Review thread on the target_modules field:
- Is this needed?
- I recall this causing an error if this wasn't set.
- Got it! It would be good to know what the error was exactly so we can understand it and also leave a comment to explain it - might be useful when we come back to it in the future.
- If I recall correctly, there was an error involving target_modules not being a parameter of a LoraConfig.
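A usage sketch, not taken from the PR: with this field in place, the adapter section of a model config dict can point either at a Hugging Face repo id or at a local directory, assuming the schema serializes these fields under the lowercase keys "type" and "pretrained_weights". The local path below is purely illustrative.

# Adapter weights hosted on the Hugging Face Hub (the repo id also used by the test below):
adapter_from_hub = {"type": "lora", "pretrained_weights": "Infernaught/test_adapter_weights"}

# Adapter weights previously saved to a local directory (hypothetical path):
adapter_from_local_path = {"type": "lora", "pretrained_weights": "/path/to/saved/lora_adapter"}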
@@ -18,6 +18,7 @@
     MODEL_TYPE,
     OUTPUT_FEATURES,
     PREPROCESSING,
+    PRETRAINED_WEIGHTS,
     PROMPT,
     TRAINER,
     TYPE,

@@ -481,6 +482,36 @@ def test_llama_rope_scaling():
     assert model.model.config.rope_scaling["factor"] == 2.0


+def test_load_pretrained_adapter_weights():
+    from peft import PeftModel
+    from transformers import PreTrainedModel
+
+    config = {
+        MODEL_TYPE: MODEL_LLM,
+        BASE_MODEL: TEST_MODEL_NAME,
+        INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})],
+        OUTPUT_FEATURES: [text_feature(name="output")],
+        TRAINER: {
+            TYPE: "finetune",
+            BATCH_SIZE: 8,
+            EPOCHS: 2,
+        },
+        ADAPTER: {TYPE: "lora", PRETRAINED_WEIGHTS: "Infernaught/test_adapter_weights"},
+        BACKEND: {TYPE: "local"},
+    }
+
+    print(ModelConfig)
+    config_obj = ModelConfig.from_dict(config)
+    model = LLM(config_obj)
+
+    assert model.config_obj.adapter.pretrained_weights
+    assert model.config_obj.adapter.pretrained_weights == "Infernaught/test_adapter_weights"
+
+    model.prepare_for_training()
+    assert not isinstance(model.model, PreTrainedModel)
+    assert isinstance(model.model, PeftModel)
+
+
 def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool:
     # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6
     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):

Review comment on test_load_pretrained_adapter_weights:
- A couple of tests we should add (possibly in a followup PR):
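The test above covers loading adapter weights from the Hub but not from a local path. Below is a rough sketch of how a follow-up test might exercise the local-path case, reusing config, model, and the constants from the test above; none of this is code from the PR, and PeftModel.save_pretrained is used only to produce a directory in the expected adapter format.

import tempfile

from peft import PeftModel

with tempfile.TemporaryDirectory() as tmpdir:
    # PeftModel.save_pretrained writes adapter_config.json plus the adapter weight file(s).
    model.model.save_pretrained(tmpdir)

    # Point a fresh model at the local directory instead of the Hub repo id.
    local_config = {**config, ADAPTER: {TYPE: "lora", PRETRAINED_WEIGHTS: tmpdir}}
    reloaded = LLM(ModelConfig.from_dict(local_config))
    reloaded.prepare_for_training()

    assert isinstance(reloaded.model, PeftModel)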
Review thread on the pretrained_weights name:
- nit, might be more clear to call it pretrained_adapter_weights since pretrained weights also come from the model! So just to avoid confusion.
- On it.
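A sketch of what that rename could look like as a drop-in replacement for the pretrained_weights field shown in the schema hunk above. This is not code from the commit; the default=None and allow_none=True choices are assumptions, not something the diff establishes.

    pretrained_adapter_weights: Optional[str] = schema_utils.String(
        default=None,
        allow_none=True,
        description="Path to pretrained weights for the LoRA adapter.",
    )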