FIX Multiple issues with target_parameters #2710
Conversation
There are a few issues with target_parameters that are fixed in this PR.

Existing parametrizations
When using target_parameters with LoRA, the LoRA parametrization is removed after the forward call finishes. However, this also used to remove all other parametrizations on the same parameter, which is bad. With this PR, only the LoRA parametrization is removed.

Module repr
This PR also extends the __repr__ of lora.ParamWrapper to contain the parameter name, which makes it more useful.

Extend testing
Added a tiny gpt-oss model to the target_parameters test suite.

Multiple LoRA adapters with target_parameters
There is an issue when adding a second LoRA adapter with target_parameters: the second adapter is not actually applied correctly. The corresponding unit test was too lax to notice the bug. This is not easy to fix, so for now we forbid adding a second adapter with target_parameters. This is very strict, but it is better than having silent errors.

Although it was possible to fix that specific issue, the solution resulted in ever more deeply nested adapters (i.e. with multiple .base_layer). This in turn results in those infixes becoming part of the state_dict. But then we cannot load the individual adapters correctly, except if the model is restored in exactly the same order as it was previously created. This is not normally a requirement in PEFT (e.g. I can create a model with two adapters and later decide to load only one of them).

In the long run, we need to think about solutions that would allow this. It may require some form of normalization of the layers to prevent ever deeper nesting. Also, what is ugly right now is that, given that the LoRA lives on a module but actually targets one of possibly multiple parameters, the LoRA weights don't reference said parameter in any name. That means that, purely from the state_dict, it is unclear which parameter a LoRA weight belongs to. Ideally, this should be encoded in the LoRA weight key.
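To make the first point concrete, here is a minimal sketch in plain PyTorch (not PEFT internals; the Scale and AddDelta classes are made up for illustration) of why tearing down a parametrization needs care when the parameter already carries other parametrizations:

import torch
from torch import nn
from torch.nn.utils import parametrize

class Scale(nn.Module):
    # stands in for a pre-existing, non-LoRA parametrization
    def forward(self, W):
        return 2 * W

class AddDelta(nn.Module):
    # stands in for the temporary LoRA parametrization
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, W):
        return W + self.delta

linear = nn.Linear(4, 4)
parametrize.register_parametrization(linear, "weight", Scale())
parametrize.register_parametrization(linear, "weight", AddDelta(linear.weight.shape))
print(len(linear.parametrizations["weight"]))  # 2: both parametrizations are stacked

# torch's remove_parametrizations is all-or-nothing: after this call, the
# pre-existing Scale is gone as well, which is the behavior the PR avoids by
# removing only the LoRA entry
parametrize.remove_parametrizations(linear, "weight")
print(parametrize.is_parametrized(linear, "weight"))  # False: Scale was lost too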
@matthewdouglas The first part of the fixes in this PR should hopefully address the issue we discussed internally with bnb. Check out the tests that show this: https://github.com/huggingface/peft/pull/2710/files#diff-c8dcf5ce96401fd88057132c84ee4a25da3550574250c9a1857b8aa1672ea540R406
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
To better illustrate the existing issues with multiple adapters, check out this script:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model
from safetensors.torch import load_file
model_id = "trl-internal-testing/tiny-Llama4ForCausalLM"
path = "/tmp/peft/target-params"
model = AutoModelForCausalLM.from_pretrained(model_id)
config = LoraConfig(
    target_modules=[],
    target_parameters=[
        "feed_forward.experts.down_proj",
        "feed_forward.experts.gate_up_proj",
    ],
)
model = get_peft_model(model, config)
model.add_adapter("other", config)
print(model)
model.save_pretrained(path)
sd = load_file(path + "/adapter_model.safetensors")
print("default sd")
for k in sd:
    print(k)
print("other sd")
sd_other = load_file(path + "/other/adapter_model.safetensors")
for k in sd:
print(k)
del model
# this works
print("loading in same order")
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path)
out = model.load_adapter(path + "/other", adapter_name="other")
print(out)
del model
print("loading in reverse order, should theoretically not matter!")
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path + "/other", adapter_name="other")
out = model.load_adapter(path, adapter_name="default")
print(out)

Running this gives us: we create a LoRA model with two adapters that use target_parameters, save them, and then load them back, once in the same order as they were created and once in reverse order.
I'm investigating the failing CI. Update: Done, a new test was not implemented correctly.
githubnemo left a comment
Two small comments, otherwise LGTM. Thanks for handling the weirdness target params entail.
src/peft/tuners/lora/layer.py
Outdated
rep = super().__repr__()
idx = rep.find("(") + 1
# insert the name of the parameter
# insert the name of the parameter to allow the repr to be disambiguous when multiple parameters on the same mdule
Suggested change:
# insert the name of the parameter to allow the repr to be disambiguous when multiple parameters on the same mdule
# insert the name of the parameter to allow the repr to be disambiguous when multiple parameters on the same module
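For illustration, here is a rough sketch of the repr trick the diff above implements; ParamWrapperSketch is a made-up stand-in, not PEFT's lora.ParamWrapper, and only the two lines locating the opening parenthesis and splicing in the parameter name mirror the actual change:

from torch import nn

class ParamWrapperSketch(nn.Module):
    def __init__(self, base_layer, parameter_name):
        super().__init__()
        self.base_layer = base_layer
        self.parameter_name = parameter_name

    def __repr__(self):
        rep = super().__repr__()
        idx = rep.find("(") + 1
        # insert the name of the targeted parameter right after the class name so the
        # repr stays unambiguous when several parameters on the same module are wrapped
        return rep[:idx] + self.parameter_name + ", " + rep[idx:]

print(ParamWrapperSketch(nn.Linear(2, 2), "weight"))
# prints something like:
# ParamWrapperSketch(weight,
#   (base_layer): Linear(in_features=2, out_features=2, bias=True)
# )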
src/peft/tuners/tuners_utils.py
Outdated
    self.targeted_parameter_names.append(key)
else:
    # Standard case: the parameter is not already parametrized. Note, however, that the model could already
    # be nested with lora.ParamWrapper.
Worthwhile adding how this comes to be (being already nested with ParamWrapper). I don't think that's obvious.
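To illustrate how the nesting arises and why it matters for loading (a generic sketch with a made-up WrapperSketch class, not PEFT's ParamWrapper): once a module that is already wrapped gets wrapped again, the .base_layer infixes stack up and leak into the state_dict keys.

from torch import nn

class WrapperSketch(nn.Module):
    # minimal stand-in for an adapter wrapper that keeps the wrapped module in .base_layer
    def __init__(self, base_layer):
        super().__init__()
        self.base_layer = base_layer

    def forward(self, x):
        return self.base_layer(x)

inner = nn.Linear(4, 4)
once = WrapperSketch(inner)   # keys: base_layer.weight, base_layer.bias
twice = WrapperSketch(once)   # keys: base_layer.base_layer.weight, base_layer.base_layer.bias
print(list(twice.state_dict().keys()))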
* FIX Multiple issues with target_parameters (#2710)
* Bump version to 0.17.1