From 00c9324f0640d834edb81b16de9826cb4e242a24 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Fri, 29 Aug 2025 11:56:05 +0200
Subject: [PATCH 1/2] FIX X-LoRA forward hook issue during generate

There was an issue that forward hooks would accumulate during generation,
since one hook per forward step was being registered and generate would call
forward multiple times. This is already undesirable, but to make it worse,
only the last hook was removed, resulting in hooks accumulating. This PR
fixes the issue.

See https://github.com/huggingface/peft/issues/1472#issuecomment-3235817807
---
 src/peft/tuners/xlora/model.py | 20 +++++++++++---------
 tests/test_xlora.py            | 13 +++++++++++++
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/src/peft/tuners/xlora/model.py b/src/peft/tuners/xlora/model.py
index 6074d684f2..3a148f66dd 100644
--- a/src/peft/tuners/xlora/model.py
+++ b/src/peft/tuners/xlora/model.py
@@ -327,12 +327,11 @@ def scalings_injection_hook(target, args, kwargs, scalings):
             kwargs["scalings"] = scalings
             return args, kwargs
 
-        handles_to_remove = None
-
-        def pre_forward(module, *args, **kwargs):
-            nonlocal handles_to_remove
+        hook_handles = []
 
+        def _pre_forward(module, *args, **kwargs):
             # =========================== Forward pass with "dummy" scalings ==================
+            nonlocal hook_handles
 
             args_real = args[0]
             kwargs_real = args[1]
@@ -340,10 +339,15 @@ def pre_forward(module, *args, **kwargs):
 
             dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real)
 
-            hook_handles = []
             for module in self.modules():
                 if isinstance(module, LoraLayer):
                     pre_forward = partial(scalings_injection_hook, scalings=dummy_scalings)
+                    existing_hooks = getattr(module, "_forward_pre_hooks", {})
+                    if any(val is scalings_injection_hook for val in existing_hooks.values()):
+                        # When calling generate, module.forward is called multiple times inside the forward hook
+                        # context, resulting in multiple hooks being registered. Therefore, we check if the hook is
+                        # already present and skip it in that case.
+                        continue
                     handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                     hook_handles.append(handle)
 
@@ -374,17 +378,15 @@ def pre_forward(module, *args, **kwargs):
                     handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                     hook_handles.append(handle)
 
-            handles_to_remove = hook_handles
-
         if not self.disabled:
-            forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True)
+            forward_handle = self.lora_model.model.register_forward_pre_hook(_pre_forward, with_kwargs=True)
 
         # Run the forward pass: first the scaling pass in the hook, and then with the base model
         yield
 
         if not self.disabled:
             # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
-            for handle in handles_to_remove:
+            for handle in hook_handles:
                 handle.remove()
             forward_handle.remove()
 
diff --git a/tests/test_xlora.py b/tests/test_xlora.py
index 6ee097906f..4a681ef2e9 100644
--- a/tests/test_xlora.py
+++ b/tests/test_xlora.py
@@ -157,6 +157,19 @@ def test_functional(self, tokenizer, model):
         )
         assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
 
+    def test_forward_hooks_are_cleaned_up(self, tokenizer, model):
+        # There was an issue that forward hooks would accumulate during generation, since one hook per forward step was
+        # being registered and generate would call forward multiple times. This is already undesirable, but to make it
+        # worse, only the last hook was removed, resulting in hooks accumulating.
+        # See https://github.com/huggingface/peft/issues/1472#issuecomment-3235817807
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=10)
+        num_hooks_gen1 = len(model.base_model.model.model.decoder.layers[0].self_attn.k_proj._forward_pre_hooks)
+
+        model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=10)
+        num_hooks_gen2 = len(model.base_model.model.model.decoder.layers[0].self_attn.k_proj._forward_pre_hooks)
+        assert num_hooks_gen1 == num_hooks_gen2 == 0
+
     def test_scalings_logging_methods(self, tokenizer, model):
         model.enable_scalings_logging()
 

From 1fcee35419168e26b5e743f2f900c7b0f01ed239 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Mon, 1 Sep 2025 16:20:27 +0200
Subject: [PATCH 2/2] Reviewer comment: use try ... finally

---
 src/peft/tuners/xlora/model.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/peft/tuners/xlora/model.py b/src/peft/tuners/xlora/model.py
index 3a148f66dd..6d36aa01d7 100644
--- a/src/peft/tuners/xlora/model.py
+++ b/src/peft/tuners/xlora/model.py
@@ -382,13 +382,13 @@ def _pre_forward(module, *args, **kwargs):
             forward_handle = self.lora_model.model.register_forward_pre_hook(_pre_forward, with_kwargs=True)
 
         # Run the forward pass: first the scaling pass in the hook, and then with the base model
-        yield
-
-        if not self.disabled:
-            # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
-            for handle in hook_handles:
-                handle.remove()
-            forward_handle.remove()
+        try:
+            yield
+        finally:
+            if not self.disabled:
+                for handle in hook_handles:
+                    handle.remove()
+                forward_handle.remove()
 
     def __getattr__(self, name: str):
         """Forward missing attributes to the wrapped module."""
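
For illustration only (not part of either patch): a minimal, self-contained sketch of the failure mode described in PATCH 1/2 and of the try ... finally cleanup that PATCH 2/2 adopts. The nn.Linear layer, the noop_pre_hook function, and the three-step loop below are hypothetical stand-ins for the LoRA layers, scalings_injection_hook, and generate(); only register_forward_pre_hook and the private _forward_pre_hooks dict (which the new test also inspects) come from the patches above.

# Minimal sketch, not X-LoRA code: a toy layer and loop standing in for the
# LoRA layers and generate(). Only the hook APIs mirror the patches above.
import torch
from torch import nn


def noop_pre_hook(module, args, kwargs):
    # Stand-in for scalings_injection_hook: passes args/kwargs through unchanged.
    return args, kwargs


# Buggy pattern (pre-fix behavior): a hook is registered on every forward
# step, but only the handle from the last step is removed afterwards.
layer = nn.Linear(4, 4)
last_handle = None
for _ in range(3):  # stands in for the multiple forward calls of generate()
    last_handle = layer.register_forward_pre_hook(noop_pre_hook, with_kwargs=True)
    layer(torch.randn(1, 4))
last_handle.remove()
print(len(layer._forward_pre_hooks))  # 2 -- stale hooks accumulate

# Fixed pattern: register once, keep every handle, and remove them in a
# finally block so cleanup also runs if the forward pass raises.
layer = nn.Linear(4, 4)
hook_handles = [layer.register_forward_pre_hook(noop_pre_hook, with_kwargs=True)]
try:
    for _ in range(3):
        layer(torch.randn(1, 4))
finally:
    for handle in hook_handles:
        handle.remove()
print(len(layer._forward_pre_hooks))  # 0 -- no hooks left behind

The skip-if-already-registered check in PATCH 1/2 and the finally block in PATCH 2/2 together ensure the hook count returns to zero after generate, which is what test_forward_hooks_are_cleaned_up asserts.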