Merged
src/peft/tuners/xlora/model.py (32 changes: 17 additions & 15 deletions)
@@ -327,23 +327,27 @@ def scalings_injection_hook(target, args, kwargs, scalings):
             kwargs["scalings"] = scalings
             return args, kwargs
 
-        handles_to_remove = None
+        hook_handles = []
 
-        def pre_forward(module, *args, **kwargs):
-            nonlocal handles_to_remove
-
+        def _pre_forward(module, *args, **kwargs):
             # =========================== Forward pass with "dummy" scalings ==================
+            nonlocal hook_handles
 
             args_real = args[0]
             kwargs_real = args[1]
             kwargs_real.update(kwargs)
 
             dummy_scalings = self.internal_xlora_classifier.make_dummy_scalings(*args_real, **kwargs_real)
 
-            hook_handles = []
             for module in self.modules():
                 if isinstance(module, LoraLayer):
                     pre_forward = partial(scalings_injection_hook, scalings=dummy_scalings)
+                    existing_hooks = getattr(module, "_forward_pre_hooks", {})
+                    if any(val is scalings_injection_hook for val in existing_hooks.values()):
+                        # When calling generate, module.forward is called multiple times inside the forward hook
+                        # context, resulting in multiple hooks being registered. Therefore, we check if the hook is
+                        # already present and skip it in that case.
+                        continue
                     handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                     hook_handles.append(handle)
 
@@ -374,19 +378,17 @@ def pre_forward(module, *args, **kwargs):
                     handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
                     hook_handles.append(handle)
 
-            handles_to_remove = hook_handles
-
         if not self.disabled:
-            forward_handle = self.lora_model.model.register_forward_pre_hook(pre_forward, with_kwargs=True)
+            forward_handle = self.lora_model.model.register_forward_pre_hook(_pre_forward, with_kwargs=True)
 
         # Run the forward pass: first the scaling pass in the hook, and then with the base model
-        yield
-
-        if not self.disabled:
-            # TODO(EricLBuehler): If we get a forward exception, we may have multiple forward hooks.
-            for handle in handles_to_remove:
-                handle.remove()
-            forward_handle.remove()
+        try:
+            yield
+        finally:
+            if not self.disabled:
+                for handle in hook_handles:
+                    handle.remove()
+                forward_handle.remove()
 
     def __getattr__(self, name: str):
         """Forward missing attributes to the wrapped module."""
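Taken together, the two hunks replace the hand-rolled cleanup (collect handles in handles_to_remove and remove them after a bare yield) with an exception-safe try/finally, and skip re-registering a pre-hook that is already attached, since generate() triggers many forward passes inside the same context. The following is a minimal standalone sketch of that pattern, not the PEFT implementation: nn.Linear stands in for LoraLayer, the injected hook is a no-op placeholder, and the duplicate check compares partial.func because this sketch registers a functools.partial.

# A minimal, standalone sketch of the hook-lifecycle pattern adopted above; not
# the PEFT implementation. Assumptions: nn.Linear stands in for LoraLayer, the
# injected hook is a placeholder, and all names here are illustrative.
from contextlib import contextmanager
from functools import partial

import torch
import torch.nn as nn


def injection_hook(module, args, kwargs, scalings):
    # Per-layer forward pre-hook (with_kwargs=True). Returning None leaves args
    # and kwargs untouched; the real X-LoRA hook injects `scalings` into kwargs.
    return None


@contextmanager
def scalings_context(model: nn.Module, scalings):
    hook_handles = []

    def _pre_forward(module, args, kwargs):
        # Runs before every forward of `model`, so it may execute many times
        # inside one context (e.g. once per generated token during generate()).
        for sub in model.modules():
            if isinstance(sub, nn.Linear):  # stand-in for LoraLayer
                already_hooked = any(
                    getattr(h, "func", None) is injection_hook
                    for h in sub._forward_pre_hooks.values()
                )
                if already_hooked:
                    # Skip modules that already carry the hook, otherwise each
                    # forward call would stack another copy of it.
                    continue
                handle = sub.register_forward_pre_hook(
                    partial(injection_hook, scalings=scalings), with_kwargs=True
                )
                hook_handles.append(handle)

    forward_handle = model.register_forward_pre_hook(_pre_forward, with_kwargs=True)
    try:
        yield
    finally:
        # Runs even if the forward pass raised, so no handle can leak.
        for handle in hook_handles:
            handle.remove()
        forward_handle.remove()


# Usage: hooks exist only inside the context and are all gone afterwards.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
with scalings_context(model, scalings=torch.zeros(1)):
    model(torch.randn(2, 4))
    model(torch.randn(2, 4))  # a second forward does not add duplicate hooks
assert all(len(m._forward_pre_hooks) == 0 for m in model.modules())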
tests/test_xlora.py (13 changes: 13 additions & 0 deletions)
@@ -157,6 +157,19 @@ def test_functional(self, tokenizer, model):
         )
         assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
 
+    def test_forward_hooks_are_cleaned_up(self, tokenizer, model):
+        # There was an issue that forward hooks would accumulate during generation, since one hook per forward step was
+        # being registered and generate would call forward multiple times. This is already undesirable, but to make it
+        # worse, only the last hook was removed, resulting in hooks accumulating.
+        # See https://github.com/huggingface/peft/issues/1472#issuecomment-3235817807
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=10)
+        num_hooks_gen1 = len(model.base_model.model.model.decoder.layers[0].self_attn.k_proj._forward_pre_hooks)
+
+        model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=10)
+        num_hooks_gen2 = len(model.base_model.model.model.decoder.layers[0].self_attn.k_proj._forward_pre_hooks)
+        assert num_hooks_gen1 == num_hooks_gen2 == 0
+
     def test_scalings_logging_methods(self, tokenizer, model):
         model.enable_scalings_logging()
 
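The new test pins its hook count to one specific projection, the k_proj of the first decoder layer of the test model. A more generic variant of the same check, sketched below as a hypothetical helper that is not part of the PEFT test suite, counts forward pre-hooks over every submodule and requires the total not to grow across generate() calls (this assumes no other component, such as dispatch or offloading hooks, registers forward pre-hooks of its own).

# Hypothetical helper, not from the PEFT test suite: sums forward pre-hooks over
# all submodules so a leaked hook is caught no matter which layer it sits on.
import torch.nn as nn


def total_forward_pre_hooks(model: nn.Module) -> int:
    return sum(len(m._forward_pre_hooks) for m in model.modules())


# Sketch of how it could be used around the same generation calls as the test
# above (model, inputs and torch_device assumed to exist in scope):
# before = total_forward_pre_hooks(model)
# model.generate(input_ids=inputs.to(torch_device), max_new_tokens=10)
# model.generate(input_ids=inputs.to(torch_device), max_new_tokens=10)
# assert total_forward_pre_hooks(model) == before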