diff --git a/src/peft/tuners/xlora/layer.py b/src/peft/tuners/xlora/layer.py
index d66d5bcaad..bf2afcd589 100644
--- a/src/peft/tuners/xlora/layer.py
+++ b/src/peft/tuners/xlora/layer.py
@@ -73,10 +73,12 @@ def get_maybe_topk_scalings(self, scalings) -> torch.Tensor:
 
             xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype)
 
+        # Apply per-token normalization to the xLoRA scaling factors using a softmax
         if self.config.enable_softmax_topk:
             nonzero_mask = xlora_scalings != 0
-            softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
-            xlora_scalings[nonzero_mask] = softmax_res_nonzero
+            full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
+            new_scalings = torch.softmax(full, dim=-1)
+            xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)
 
         return xlora_scalings
 
diff --git a/src/peft/tuners/xlora/model.py b/src/peft/tuners/xlora/model.py
index 6d36aa01d7..25e0902bfd 100644
--- a/src/peft/tuners/xlora/model.py
+++ b/src/peft/tuners/xlora/model.py
@@ -368,6 +368,8 @@ def _pre_forward(module, *args, **kwargs):
 
                 self.lora_model.enable_adapter_layers()
 
         xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
+        # Store computed scalings to fix get_latest_scalings() returning None
+        self.internal_xlora_scalings = xlora_scalings
 
         # =========================== Real forward pass with calculated scalings ==================
diff --git a/tests/test_xlora.py b/tests/test_xlora.py
index 4a681ef2e9..724e7782cb 100644
--- a/tests/test_xlora.py
+++ b/tests/test_xlora.py
@@ -23,6 +23,7 @@
 
 from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model
 from peft.peft_model import PeftModel
+from peft.tuners.xlora.layer import XLoraLayer
 from peft.utils import infer_device
 
 
@@ -381,3 +382,45 @@ def test_xlora_loading_valid(self):
         w1 = sd["base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight"]
 
         assert torch.allclose(w0, w1)
+
+    def test_scalings_storage(self, tokenizer, model):
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=10,
+        )
+
+        latest_scalings = model.get_latest_scalings()
+        assert latest_scalings is not None, "get_latest_scalings() should not return None after generation"
+        assert isinstance(latest_scalings, torch.Tensor)
+        assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"
+
+    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model, monkeypatch):
+        model.internal_xlora_classifier.config.top_k_lora = 2
+        model.internal_xlora_classifier.config.enable_softmax = False
+        model.internal_xlora_classifier.config.enable_softmax_topk = True
+
+        captured_data = []
+        orig_get_maybe_topk_scalings = XLoraLayer.get_maybe_topk_scalings
+
+        def mock_get_maybe_topk_scalings(self, scalings):
+            result = orig_get_maybe_topk_scalings(self, scalings)
+            if getattr(model, "internal_xlora_scalings", None) is not None:
+                captured_data.append(result)
+            return result
+
+        monkeypatch.setattr(XLoraLayer, "get_maybe_topk_scalings", mock_get_maybe_topk_scalings)
+
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=1,
+        )
+
+        for scaling in captured_data:
+            weight_sums = scaling.sum(dim=-1)
+            assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
+                "Per-token scaling weights are not normalized to sum to 1."
+            )
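
For reviewers who want to sanity-check the new normalization outside the test suite, here is a minimal standalone sketch of the masked-softmax technique the `layer.py` hunk adopts. It uses plain PyTorch with no PEFT dependency; the tensor shape `(2, 4, 8)` and `top_k = 2` are illustrative assumptions, not xLoRA defaults.

```python
import torch

# Toy stand-in for xLoRA scalings: [batch_size, seq_len, n_adapters].
# Shape and top_k are arbitrary choices for illustration.
xlora_scalings = torch.randn(2, 4, 8)
top_k = 2

# Keep only the top-k adapters per token and zero out the rest,
# mirroring the top_k_lora masking that precedes the patched code.
_, topk_indices = torch.topk(xlora_scalings, k=top_k, dim=-1)
mask = torch.zeros_like(xlora_scalings, dtype=torch.bool)
mask.scatter_(-1, topk_indices, True)
xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype)

# The patched normalization: masked positions are filled with -inf so the
# softmax assigns them exactly zero weight, then re-zeroed afterwards to
# keep the tensor clean of softmax round-off on the masked slots.
nonzero_mask = xlora_scalings != 0
full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
new_scalings = torch.softmax(full, dim=-1)
xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)

# Each token's surviving adapter weights now sum to 1.
weight_sums = xlora_scalings.sum(dim=-1)
assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5)
```

Because `masked_fill` preserves the tensor's shape, `softmax(..., dim=-1)` normalizes each token's row independently. The previous boolean-indexed version, `torch.softmax(xlora_scalings[nonzero_mask], dim=-1)`, flattened all nonzero entries into a single 1-D vector, so the normalization ran across the whole batch rather than per token, which is the behavior the new `test_per_token_normalization_with_softmax_topk` guards against.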