
Conversation

BenjaminBossan (Member) commented on Sep 1, 2025

See #2759

Description

At the moment, we strongly couple the active adapter with requires_grad=True. Concretely, when we call model.set_adapter(name), we automatically assume that this adapter should not only be made active, but also that its requires_grad should be set to True.

For the purpose of training PEFT models, this is fair. However, when loading PEFT models for inference, it is not desired. Generally, for inference, we don't need requires_grad=True, but as things stand, it is enabled anyway.

Generally, this is not a severe bug: the inference code does not perform any weight updates, so we don't inadvertently update a weight just because it wrongly has requires_grad=True -- this is probably why it went unnoticed so far. However, it can lead to worse runtime performance and memory overhead when PyTorch records grads for those parameters (which it shouldn't if the forward pass runs under torch.inference_mode, but some users may forget to use this). Therefore, this bug is still worth fixing.
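
For reference, a minimal sketch of the safer inference pattern, reusing the model ID and adapter path from the example below; wrapping the forward pass in torch.inference_mode stops autograd from recording operations regardless of the requires_grad flags:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
model = PeftModel.from_pretrained(base, "/tmp/peft/2759")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.inference_mode():
    # nothing is recorded for autograd here, even if some params still have requires_grad=True
    outputs = model(**inputs)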

Example

With modules_to_save

A very basic example where the current PEFT code fails:

import os
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

model_id = "facebook/opt-125m"
path = "/tmp/peft/2759"
if not os.path.exists(path + "/adapter_model.safetensors"):
    model = AutoModelForCausalLM.from_pretrained(model_id)
    config = LoraConfig(target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"], r=8)
    model = get_peft_model(model, config)
    model.save_pretrained(path)
    del model

model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path)
assert not model.base_model.model.lm_head.modules_to_save.default.weight.requires_grad

The modules_to_save weights should not have grads enabled, but currently they do.

With multiple adapters

There is also an issue when loading more than one adapter:

model = PeftModel.from_pretrained(...)
assert not any(p.requires_grad for p in model.parameters())  # works

So far, so good, the first adapter does not have requires_grad.

model.load_adapter(...)
assert not any(p.requires_grad for p in model.parameters())  # fails

The load_adapter call inadvertently sets requires_grad=True for the weights of the first adapter. This happens because, when the second adapter is loaded, we call set_adapter with the first adapter to ensure that it remains the active adapter. Due to the coupling of active adapter and requires_grad, this results in setting requires_grad=True for the first adapter.
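
Conceptually, the old coupling looks roughly like this (a simplified sketch, not the actual PEFT implementation):

# simplified illustration of the coupling, not the actual PEFT code
def set_adapter(layer, adapter_name):
    layer.active_adapter = adapter_name
    for name, param in layer.named_parameters():
        # activating an adapter also enables grads for its parameters
        if adapter_name in name:
            param.requires_grad = True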

This PR relaxes the coupling by allowing set_adapter to be called with an additional argument, inference_mode. If it is set to True, requires_grad is not enabled, even though the adapter is activated.
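
A rough sketch of the resulting call pattern (using the adapter name "default" for illustration and assuming that the behavior without inference_mode stays as before):

# activate the adapter without enabling grads, e.g. when loading for inference
model.set_adapter("default", inference_mode=True)
assert not any(p.requires_grad for p in model.parameters())

# without inference_mode, activating still enables grads, as needed for training
model.set_adapter("default")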

The example above would also fail for modules_to_save and trainable_token_indices, not only for the LoRA/LoHa/... weights.

Still open bugs

The proposed solution is unfortunately not perfect. Right now, we do pass inference_mode based on the PEFT config of the adapter being added, which helps with the original issue described above. However, even this is not absolutely correct, because inference_mode of the second adapter does not necessarily have the same value as inference_mode of the first adapter. To illustrate how this can go wrong, I added an xfailing test:

test_loading_model_requires_grad_set_correctly_switch_inference_mode

I believe that this use case is rarer than the ones described at the beginning, so IMO it is acceptable to keep this bug in exchange for fixing the more common ones. However, LMK if you disagree.
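
Roughly, the remaining failure mode looks like this (a hedged sketch with illustrative paths and names; the xfailing test covers the actual load_adapter scenario):

# the first adapter was saved with inference_mode=True, the second with inference_mode=False
model = PeftModel.from_pretrained(base_model, path_0)  # grads stay off, as desired
model.load_adapter(path_1, adapter_name="other")
# to keep the first adapter active, set_adapter is called again, but with the
# second adapter's inference_mode (False), so requires_grad is switched back on
# for the first adapter even though its own config said inference_mode=True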

Related to this, I noticed that many tests in test_custom_models.TestRequiresGrad had code like this:

config0 = FooConfig(...)
peft_model = get_peft_model(MLP(), config0)
config1 = FooConfig(..., inference_mode=True)  # <==
peft_model.add_adapter("adapter1", config1)

This now fails for the reason just given. I removed inference_mode=True here and the tests pass again.

Note that the only reason inference_mode=True was passed here is that AdaLoRA cannot load 2 adapters in training mode and thus requires it. Later PEFT methods without this restriction blindly copied the AdaLoRA test. For those PEFT methods, I removed inference_mode=True to make them pass.

However, this also means that the AdaLoRA tests now fail. I thus marked them as xfail.

To properly fix this bug, I think we would have to refactor the code so that set_adapter (i.e. determining the active adapter) and setting requires_grad live in separate code paths, as they're orthogonal. Moreover, these attributes are currently set all over the place, which makes it hard to reason about where they change. This should be streamlined.
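
For example, one possible shape for such a split (purely hypothetical, not what this PR implements):

# purely hypothetical sketch of the suggested split, not part of this PR
class TunerLayerSketch:
    def set_adapter(self, adapter_names):
        # only decides which adapter(s) participate in the forward pass
        self._active_adapter = adapter_names

    def set_requires_grad(self, adapter_names, requires_grad=True):
        # the single, explicit place where gradient bookkeeping happens
        for name in adapter_names:
            for param in self.adapter_parameters(name):  # hypothetical helper
                param.requires_grad = requires_grad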

Making these changes while not breaking any existing code is not trivial (or maybe even impossible). Therefore, I went the easier way for the time being with this PR. Maybe a bigger refactor could be envisioned for a version 1.0 release of PEFT.

Related changes

While working on this, I noticed that LNTuning was buggy when calling set_adapter. This is now fixed.

Moreover, since I had to touch update_layer everywhere, I ensured that all implementations take kwargs for consistency.

Note to maintainers

  • Most changes in this PR are just the same updates for set_adapter and update_layer in each PEFT method's layer.py and model.py (except for prompt learning) with some diff noise due to updating type annotations and docstrings. For the review, focus on the changes in other.py, peft_model.py, tuners_utils.py, and test_custom_models.py.
  • If/When this PR is merged, existing PRs that add new PEFT methods have to be updated to reflect the changes.
  • Yes, it is about time we update the abstractions so that these types of changes become easier in the future (not having to update each PEFT method individually).

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

githubnemo (Collaborator) left a comment:

Minor nit, otherwise LGTM. Thanks for taking care of this!

extra_kwargs = {}
if config_cls == IA3Config:
    extra_kwargs["feedforward_modules"] = []
# targeting the different modules with modules_to_save:
Suggested change: remove the comment line
# targeting the different modules with modules_to_save:

BenjaminBossan merged commit 13fa0ae into huggingface:main on Sep 8, 2025
14 checks passed
BenjaminBossan deleted the fix-set-adapter-coupling-with-requires-grad branch on September 8, 2025 at 17:49
BenjaminBossan mentioned this pull request on Sep 10, 2025