FIX: Wrong coupling between requires_grad and the active adapter #2765
Description
At the moment, we strongly couple the active adapter with
requires_grad=True. Concretely, when we call model.set_adapter(name), we
automatically assume that this adapter should not only be made active but
that its requires_grad should also be set to True.
For the purpose of training PEFT models, this is fair. However, when
loading PEFT models for inference, this is not desired. Generally, for
inference, we don't need requires_grad=True, yet as things stand it is enabled.
Generally, this is not a severe bug: in the inference code, we don't
perform any updates, so we don't inadvertently update a weight just because
it wrongly has requires_grad=True. However, it can lead to worse runtime
performance and memory overhead when PyTorch records gradients for those
parameters (which it shouldn't if the model is called under
torch.inference_mode, but some users may forget to use it). Therefore,
this bug is still worth fixing.
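To illustrate the performance point, here is a minimal sketch (plain PyTorch, not PEFT-specific): outside of torch.inference_mode, autograd tracks any parameter with requires_grad=True, while inside it nothing is recorded.
import torch

linear = torch.nn.Linear(8, 8)           # stands in for an adapter layer
linear.weight.requires_grad_(True)       # a weight that was wrongly left trainable
x = torch.randn(2, 8)

out = linear(x)
print(out.requires_grad)                 # True -> autograd builds a graph for this pass

with torch.inference_mode():
    out = linear(x)
    print(out.requires_grad)             # False -> no graph is recorded here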
Example
A very basic example where PEFT currently fails:
model = PeftModel.from_pretrained(...)
assert not any(p.requires_grad for p in model.parameters()) # works
So far, so good, the first adapter does not have requires_grad.
model.load_adapter(...)
assert not any(p.requires_grad for p in model.parameters()) # fails
The load_adapter call inadvertently sets requires_grad=True for the
weights of the _first_ adapter. This happens because, when the second
adapter is loaded, we call set_adapter with the first adapter to ensure
that it remains the active adapter. However, due to the coupling of the
active adapter and requires_grad, this results in setting
requires_grad=True for the first adapter.
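For a more self-contained view, the same sequence as a sketch; the checkpoint and adapter names ("base-model", "lora-adapter-A", "lora-adapter-B") are placeholders, not real repos:
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("base-model")      # placeholder checkpoint
model = PeftModel.from_pretrained(base, "lora-adapter-A")      # first adapter, loaded for inference
assert not any(p.requires_grad for p in model.parameters())    # holds

model.load_adapter("lora-adapter-B", adapter_name="other")     # second adapter
# load_adapter re-activates the first adapter via set_adapter, which (before
# this PR) also flipped its weights back to requires_grad=True:
assert not any(p.requires_grad for p in model.parameters())    # failed prior to this fix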
The PR relaxes this coupling by allowing set_adapter to be called with an
additional argument, inference_mode. If set to True, requires_grad will
not be enabled, even if the adapter is activated.
The example above would also fail for modules_to_save and trainable
tokens, not only for the LoRA/LoHa/... weights.
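As a rough usage sketch of the new argument (exactly which objects expose it is not spelled out here, so treat this as illustrative):
model.set_adapter("default", inference_mode=True)   # activate the adapter without enabling requires_grad
model.set_adapter("default")                        # default behaviour is unchanged: activate and enable grads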
Still open bugs
The proposed solution is unfortunately not perfect. Right now, we do
pass inference_mode based on the PEFT config of the adapter being added,
which helps with the original issue described above. However, even this
is not absolutely correct, because inference_mode of the second adapter
does not necessarily have the same value as inference_mode of the first
adapter. To illustrate how this can go wrong, I added an xfailing test:
test_loading_model_requires_grad_set_correctly_switch_inference_mode
I believe that this use case is rarer than the one described at the
beginning, so IMO it is okay to accept this bug for now, since we fix a more
common one. However, LMK if you disagree.
Related to this, I noticed that many tests in
test_custom_models.TestRequiresGrad had code like this:
config0 = FooConfig(...)
peft_model = get_peft_model(MLP(), config0)
config1 = FooConfig(..., inference_mode=True) # <==
peft_model.add_adapter("adapter1", config1)
This now fails because of the reason just given. I removed
inference_mode=True here and the tests pass again.
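For illustration, the adjusted pattern looks roughly like this (FooConfig and MLP are the same placeholders as in the snippet above):
config0 = FooConfig(...)
peft_model = get_peft_model(MLP(), config0)
config1 = FooConfig(...)                     # no inference_mode=True anymore
peft_model.add_adapter("adapter1", config1)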
Note that the only reason why inference_mode=True was passed here is
because AdaLoRA cannot load 2 adapters in training mode and thus
requires this. Later PEFT methods without this restriction blindly
copied the AdaLoRA test. For those PEFT methods, I removed
inference_mode=True.
However, this also means that the AdaLoRA tests now fail. I thus marked
them as xfail.
To properly fix this bug, I think we would have to refactor the code to
isolate set_adapter (i.e. determining the active adapter) and setting
requires_grad into separate code paths, as they're orthogonal. Moreover,
these attributes are being set all over the place, which makes it hard
to reason about where these attributes are being changed. This should be
streamlined.
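Purely as a hypothetical sketch of what such a separation could look like (set_requires_grad is an invented name, not an existing PEFT API):
model.set_adapter("default")                 # only decides which adapter is active
model.set_requires_grad("default", True)     # hypothetical, separate knob for gradient bookkeeping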
Making these changes while not breaking any existing code is not
trivial (or maybe even impossible). Therefore, I went the easier way for
the time being with this PR. Maybe a bigger refactor could be envisioned
for a version 1.0 release of PEFT.
Related changes
While working on this, I noticed that LNTuning was completely buggy when
calling set_adapter. This is now fixed.
Moreover, since I had to touch update_layer everywhere, I ensured that
all of them accept kwargs for consistency.
Note to maintainers:
- If/When this PR is merged, existing PRs that add new PEFT methods have
to be updated to reflect the changes.
- The changes touch set_adapter and update_layer in each PEFT method's
layer.py and model.py (except for prompt learning), with some diff noise
due to updated type annotations and docstrings. For the review, focus on
the changes in other.py, peft_model.py, tuners_utils.py, and
test_custom_models.py.
- Yes, it is about time we update the abstractions so that these types
of changes become easier in the future (not having to update each PEFT
method individually).
githubnemo left a comment:
Minor nit, otherwise LGTM. Thanks for taking care of this!
tests/test_custom_models.py (Outdated)
extra_kwargs = {}
if config_cls == IA3Config:
    extra_kwargs["feedforward_modules"] = []
# targeting the different modules with modules_to_save:
# targeting the different modules with modules_to_save:
See #2759