diff --git a/src/peft/tuners/xlora/model.py b/src/peft/tuners/xlora/model.py
index 6074d684f2..6d36aa01d7 100644
--- a/src/peft/tuners/xlora/model.py
+++ b/src/peft/tuners/xlora/model.py
@@ -327,12 +327,11 @@ def scalings_injection_hook(target, args, kwargs, scalings):
             kwargs["scalings"] = scalings
             return args, kwargs
 
-        handles_to_remove = None
-
-        def pre_forward(module, *args, **kwargs):
-            nonlocal handles_to_remove
+        hook_handles = []
 
+        def _pre_forward(module, *args, **kwargs):
             # =========================== Forward pass with "dummy" scalings ==================
+            nonlocal hook_handles
 
             args_real = args[0]
             kwargs_real = args[1]
@@ -340,10 +339,15 @@ def pre_forward(module, *args, **kwargs):
 
             dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real)
 
-            hook_handles = []
             for module in self.modules():
                 if isinstance(module, LoraLayer):
                     pre_forward = partial(scalings_injection_hook, scalings=dummy_scalings)
+                    existing_hooks = getattr(module, "_forward_pre_hooks", {})
+                    if any(getattr(val, "func", None) is scalings_injection_hook for val in existing_hooks.values()):
+                        # When calling generate, module.forward is called multiple times inside the forward hook
+                        # context, resulting in multiple hooks being registered. Therefore, we check if the hook is
+                        # already present and skip it in that case.
+                        continue
                     handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                     hook_handles.append(handle)
 
@@ -374,19 +378,17 @@ def pre_forward(module, *args, **kwargs):
                     handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                     hook_handles.append(handle)
 
-            handles_to_remove = hook_handles
-
         if not self.disabled:
-            forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True)
+            forward_handle = self.lora_model.model.register_forward_pre_hook(_pre_forward, with_kwargs=True)
 
         # Run the forward pass: first the scaling pass in the hook, and then with the base model
-        yield
-
-        if not self.disabled:
-            # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
-            for handle in handles_to_remove:
-                handle.remove()
-            forward_handle.remove()
+        try:
+            yield
+        finally:
+            if not self.disabled:
+                for handle in hook_handles:
+                    handle.remove()
+                forward_handle.remove()
 
     def __getattr__(self, name: str):
         """Forward missing attributes to the wrapped module."""
diff --git a/tests/test_xlora.py b/tests/test_xlora.py
index 6ee097906f..4a681ef2e9 100644
--- a/tests/test_xlora.py
+++ b/tests/test_xlora.py
@@ -157,6 +157,19 @@ def test_functional(self, tokenizer, model):
         )
         assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
 
+    def test_forward_hooks_are_cleaned_up(self, tokenizer, model):
+        # There was an issue where forward hooks would accumulate during generation, since one hook per forward step
+        # was being registered and generate calls forward multiple times. To make it worse, only the last hook was
+        # removed, so the leftover hooks accumulated across generate calls.
+        # See https://github.com/huggingface/peft/issues/1472#issuecomment-3235817807
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=10)
+        num_hooks_gen1 = len(model.base_model.model.model.decoder.layers[0].self_attn.k_proj._forward_pre_hooks)
+
+        model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=10)
+        num_hooks_gen2 = len(model.base_model.model.model.decoder.layers[0].self_attn.k_proj._forward_pre_hooks)
+        assert num_hooks_gen1 == num_hooks_gen2 == 0
+
     def test_scalings_logging_methods(self, tokenizer, model):
         model.enable_scalings_logging()
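
For context on the hook mechanics the patch relies on, here is a minimal, self-contained sketch (not PEFT code; `ScaledLinear`, `injection_hook`, and `temporary_hooks` are hypothetical names invented for illustration) of the same pattern: a `functools.partial` wrapping a pre-hook is registered per layer with `with_kwargs=True`, layers that already carry the hook are skipped so repeated forward calls do not stack duplicates, and a `try`/`finally` guarantees the handles are removed even if the forward pass raises.

```python
import functools
from contextlib import contextmanager

import torch
import torch.nn as nn


class ScaledLinear(nn.Linear):
    # Hypothetical toy layer whose forward accepts the injected keyword argument.
    def forward(self, x, scalings=1.0):
        return super().forward(x) * scalings


def injection_hook(module, args, kwargs, scalings):
    # Forward pre-hook (registered with with_kwargs=True): add a keyword argument
    # to the call before module.forward runs.
    kwargs["scalings"] = scalings
    return args, kwargs


@contextmanager
def temporary_hooks(model, scalings):
    handles = []
    try:
        for module in model.modules():
            if isinstance(module, ScaledLinear):
                # Skip layers that already carry this hook (the dedup check in the patch);
                # otherwise repeated forward calls would register the hook again and again.
                already_hooked = any(
                    getattr(h, "func", None) is injection_hook for h in module._forward_pre_hooks.values()
                )
                if already_hooked:
                    continue
                hook = functools.partial(injection_hook, scalings=scalings)
                handles.append(module.register_forward_pre_hook(hook, with_kwargs=True))
        yield
    finally:
        # Always remove the handles, even if the forward pass raised, so hooks never accumulate.
        for handle in handles:
            handle.remove()


model = nn.Sequential(ScaledLinear(4, 4), ScaledLinear(4, 4))
x = torch.randn(2, 4)
with temporary_hooks(model, scalings=0.5):
    out = model(x)  # each ScaledLinear.forward receives scalings=0.5 via its pre-hook
assert all(len(m._forward_pre_hooks) == 0 for m in model.modules())  # all hooks cleaned up
```

The previous code removed handles only after an unguarded `yield`, so an exception during the forward pass, or re-registration during `generate`, left hooks behind; the `try`/`finally` plus the dedup check in the patch closes both gaps, and the new test asserts the hook count on a LoRA layer is zero after each `generate` call.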