From 4bde402ed01620861cdccca599f0c9dff518b2b0 Mon Sep 17 00:00:00 2001
From: Che-Xu
Date: Sun, 21 Sep 2025 20:25:34 +0800
Subject: [PATCH 1/4] Store xlora scaling and fix per token normalization

---
 src/peft/tuners/xlora/layer.py | 7 +++++++
 src/peft/tuners/xlora/model.py | 2 ++
 2 files changed, 9 insertions(+)

diff --git a/src/peft/tuners/xlora/layer.py b/src/peft/tuners/xlora/layer.py
index d66d5bcaad..fb8b94aba6 100644
--- a/src/peft/tuners/xlora/layer.py
+++ b/src/peft/tuners/xlora/layer.py
@@ -78,6 +78,13 @@ def get_maybe_topk_scalings(self, scalings) -> torch.Tensor:
             softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
             xlora_scalings[nonzero_mask] = softmax_res_nonzero
 
+        # Apply per-token normalization to the xLoRA scaling factors using a softmax
+        if self.config.enable_softmax_topk:
+            nonzero_mask = xlora_scalings != 0
+            full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
+            new_scalings = torch.softmax(full, dim=-1)
+            xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)
+
         return xlora_scalings
 
 
diff --git a/src/peft/tuners/xlora/model.py b/src/peft/tuners/xlora/model.py
index 6d36aa01d7..25e0902bfd 100644
--- a/src/peft/tuners/xlora/model.py
+++ b/src/peft/tuners/xlora/model.py
@@ -368,6 +368,8 @@ def _pre_forward(module, *args, **kwargs):
                     self.lora_model.enable_adapter_layers()
 
             xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
+            # Store computed scalings to fix get_latest_scalings() returning None
+            self.internal_xlora_scalings = xlora_scalings
 
             # =========================== Real forward pass with calculated scalings ==================
 

From af982fe6a475c8102ac3c7f2e6472972d7f93f9d Mon Sep 17 00:00:00 2001
From: Che-Xu
Date: Sun, 21 Sep 2025 20:48:18 +0800
Subject: [PATCH 2/4] Store xlora scaling and fix per token normalization

---
 src/peft/tuners/xlora/layer.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/src/peft/tuners/xlora/layer.py b/src/peft/tuners/xlora/layer.py
index fb8b94aba6..bf2afcd589 100644
--- a/src/peft/tuners/xlora/layer.py
+++ b/src/peft/tuners/xlora/layer.py
@@ -73,11 +73,6 @@ def get_maybe_topk_scalings(self, scalings) -> torch.Tensor:
 
             xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype)
 
-        if self.config.enable_softmax_topk:
-            nonzero_mask = xlora_scalings != 0
-            softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
-            xlora_scalings[nonzero_mask] = softmax_res_nonzero
-
         # Apply per-token normalization to the xLoRA scaling factors using a softmax
         if self.config.enable_softmax_topk:
             nonzero_mask = xlora_scalings != 0

From 95ca0b845c5520f8f5a8b4644496c708f3123965 Mon Sep 17 00:00:00 2001
From: Che-Xu
Date: Wed, 8 Oct 2025 19:51:31 +0800
Subject: [PATCH 3/4] Add unit tests to check for the bugs

---
 tests/test_xlora.py | 124 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 124 insertions(+)

diff --git a/tests/test_xlora.py b/tests/test_xlora.py
index 4a681ef2e9..ad94b7b357 100644
--- a/tests/test_xlora.py
+++ b/tests/test_xlora.py
@@ -16,6 +16,7 @@ from functools import wraps
 
 import huggingface_hub
+import numpy as np
 import pytest
 import torch
 from safetensors.torch import load_file
 
@@ -381,3 +382,126 @@ def test_xlora_loading_valid(self):
         w1 = sd["base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight"]
 
         assert torch.allclose(w0, w1)
+
+    def test_scalings_storage(self, tokenizer, model):
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=10,
+        )
+
+        latest_scalings = model.get_latest_scalings()
+        assert latest_scalings is not None, "get_latest_scalings() should not return None after generation"
+        assert isinstance(latest_scalings, torch.Tensor)
+        assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"
+
+    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model):
+        captured_data = []
+
+        model.internal_xlora_classifier.config.top_k_lora = 2
+        model.internal_xlora_classifier.config.enable_softmax = False
+        model.internal_xlora_classifier.config.enable_softmax_topk = True
+
+        monkeypatches = []
+
+        def wrap_decoder_layers():
+            wrapped_count = 0
+
+            decoder_layers = model.base_model.lora_model.model.model.decoder.layers
+            for layer_idx, decoder_layer in enumerate(decoder_layers):
+                attention_layers = [
+                    ("q_proj", decoder_layer.self_attn.q_proj),
+                    ("k_proj", decoder_layer.self_attn.k_proj),
+                    ("v_proj", decoder_layer.self_attn.v_proj),
+                ]
+
+                for proj_name, proj_layer in attention_layers:
+                    if (
+                        hasattr(proj_layer, "forward")
+                        and hasattr(proj_layer.forward, "__self__")
+                        and type(proj_layer.forward.__self__).__name__ == "XLoraLinearLayer"
+                    ):
+                        xlora_wrapper = proj_layer.forward.__self__
+                        original_forward = proj_layer.forward
+
+                        def create_wrapped_forward(orig_forward, layer_idx, name, xlora_wrapper):
+                            def wrapped_forward(*args, **kwargs):
+                                if (
+                                    hasattr(model, "internal_xlora_scalings")
+                                    and model.internal_xlora_scalings is not None
+                                ):
+                                    result = orig_forward(*args, **kwargs)
+
+                                    scalings = kwargs.get("scalings", None)
+                                    normalized_scalings = None
+                                    if scalings is not None and hasattr(xlora_wrapper, "get_maybe_topk_scalings"):
+                                        normalized_scalings = xlora_wrapper.get_maybe_topk_scalings(scalings)
+
+                                    capture_info = {
+                                        "layer": layer_idx,
+                                        "projection": name,
+                                        "result_shape": result.shape if hasattr(result, "shape") else "unknown",
+                                        "timestamp": len(captured_data),
+                                        "normalized_scalings": normalized_scalings,
+                                    }
+                                    captured_data.append(capture_info)
+                                    return result
+                                else:
+                                    return orig_forward(*args, **kwargs)
+
+                            return wrapped_forward
+
+                        wrapper = create_wrapped_forward(original_forward, layer_idx, proj_name, xlora_wrapper)
+
+                        mp = pytest.MonkeyPatch()
+                        mp.setattr(proj_layer, "forward", wrapper)
+                        monkeypatches.append(mp)
+                        wrapped_count += 1
+
+            return wrapped_count
+
+        total_wrapped = wrap_decoder_layers()
+        assert total_wrapped > 0, "No X-LoRA layers were wrapped for testing."
+
+        try:
+            model.enable_scalings_logging()
+            inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
+            outputs = model.generate(
+                input_ids=inputs.to(self.torch_device),
+                max_new_tokens=1,
+            )
+
+            assert len(captured_data) > 0, "No scaling data was captured during forward pass."
+
+            for data in captured_data:
+                normalized_scalings = data.get("normalized_scalings")
+                if normalized_scalings is None:
+                    assert normalized_scalings is not None, (
+                        f"Missing normalized_scalings in layer {data['layer']} {data['projection']}"
+                    )
+                    continue
+
+                if hasattr(normalized_scalings, "cpu"):
+                    scalings_np = normalized_scalings.cpu().detach().numpy()
+                else:
+                    scalings_np = normalized_scalings
+
+                # expected shape: (batch, seq_len, num_loras) or (batch, seq_len, top_k)
+                assert scalings_np.ndim == 3, (
+                    f"Unexpected scalings shape {scalings_np.shape} in layer {data['layer']} {data['projection']}"
+                )
+
+                batch_size, seq_len, num_experts = scalings_np.shape
+                for b in range(batch_size):
+                    for t in range(seq_len):
+                        weights = scalings_np[b, t, :]
+                        weight_sum = weights.sum()
+                        assert np.isclose(weight_sum, 1.0, atol=1e-5), (
+                            f"Per-token scaling not normalized in layer {data['layer']} {data['projection']}, "
+                            f"batch={b}, token={t}: sum={weight_sum:.6f}, weights={weights}"
+                        )
+
+        finally:
+            for mp in monkeypatches:
+                mp.undo()

From ae5167f895579fc91d107104ba8d028d920a952b Mon Sep 17 00:00:00 2001
From: Che-Xu
Date: Thu, 9 Oct 2025 11:51:05 +0800
Subject: [PATCH 4/4] Simplify the test_per_token_normalization_with_softmax_topk

---
 tests/test_xlora.py | 125 ++++++++------------------------------------
 1 file changed, 22 insertions(+), 103 deletions(-)

diff --git a/tests/test_xlora.py b/tests/test_xlora.py
index ad94b7b357..724e7782cb 100644
--- a/tests/test_xlora.py
+++ b/tests/test_xlora.py
@@ -16,7 +16,6 @@ from functools import wraps
 
 import huggingface_hub
-import numpy as np
 import pytest
 import torch
 from safetensors.torch import load_file
 
@@ -24,6 +23,7 @@
 
 from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model
 from peft.peft_model import PeftModel
+from peft.tuners.xlora.layer import XLoraLayer
 from peft.utils import infer_device
 
 
@@ -396,112 +396,31 @@ def test_scalings_storage(self, tokenizer, model):
         assert isinstance(latest_scalings, torch.Tensor)
         assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"
 
-    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model):
-        captured_data = []
-
+    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model, monkeypatch):
         model.internal_xlora_classifier.config.top_k_lora = 2
         model.internal_xlora_classifier.config.enable_softmax = False
         model.internal_xlora_classifier.config.enable_softmax_topk = True
 
-        monkeypatches = []
-
-        def wrap_decoder_layers():
-            wrapped_count = 0
-
-            decoder_layers = model.base_model.lora_model.model.model.decoder.layers
-            for layer_idx, decoder_layer in enumerate(decoder_layers):
-                attention_layers = [
-                    ("q_proj", decoder_layer.self_attn.q_proj),
-                    ("k_proj", decoder_layer.self_attn.k_proj),
-                    ("v_proj", decoder_layer.self_attn.v_proj),
-                ]
-
-                for proj_name, proj_layer in attention_layers:
-                    if (
-                        hasattr(proj_layer, "forward")
-                        and hasattr(proj_layer.forward, "__self__")
-                        and type(proj_layer.forward.__self__).__name__ == "XLoraLinearLayer"
-                    ):
-                        xlora_wrapper = proj_layer.forward.__self__
-                        original_forward = proj_layer.forward
-
-                        def create_wrapped_forward(orig_forward, layer_idx, name, xlora_wrapper):
-                            def wrapped_forward(*args, **kwargs):
-                                if (
-                                    hasattr(model, "internal_xlora_scalings")
-                                    and model.internal_xlora_scalings is not None
-                                ):
-                                    result = orig_forward(*args, **kwargs)
-
-                                    scalings = kwargs.get("scalings", None)
-                                    normalized_scalings = None
-                                    if scalings is not None and hasattr(xlora_wrapper, "get_maybe_topk_scalings"):
-                                        normalized_scalings = xlora_wrapper.get_maybe_topk_scalings(scalings)
-
-                                    capture_info = {
-                                        "layer": layer_idx,
-                                        "projection": name,
-                                        "result_shape": result.shape if hasattr(result, "shape") else "unknown",
-                                        "timestamp": len(captured_data),
-                                        "normalized_scalings": normalized_scalings,
-                                    }
-                                    captured_data.append(capture_info)
-                                    return result
-                                else:
-                                    return orig_forward(*args, **kwargs)
-
-                            return wrapped_forward
-
-                        wrapper = create_wrapped_forward(original_forward, layer_idx, proj_name, xlora_wrapper)
-
-                        mp = pytest.MonkeyPatch()
-                        mp.setattr(proj_layer, "forward", wrapper)
-                        monkeypatches.append(mp)
-                        wrapped_count += 1
-
-            return wrapped_count
-
-        total_wrapped = wrap_decoder_layers()
-        assert total_wrapped > 0, "No X-LoRA layers were wrapped for testing."
-
-        try:
-            model.enable_scalings_logging()
-            inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
-            outputs = model.generate(
-                input_ids=inputs.to(self.torch_device),
-                max_new_tokens=1,
-            )
+        captured_data = []
+        orig_get_maybe_topk_scalings = XLoraLayer.get_maybe_topk_scalings
 
-            assert len(captured_data) > 0, "No scaling data was captured during forward pass."
+        def mock_get_maybe_topk_scalings(self, scalings):
+            result = orig_get_maybe_topk_scalings(self, scalings)
+            if getattr(model, "internal_xlora_scalings", None) is not None:
+                captured_data.append(result)
+            return result
 
-            for data in captured_data:
-                normalized_scalings = data.get("normalized_scalings")
-                if normalized_scalings is None:
-                    assert normalized_scalings is not None, (
-                        f"Missing normalized_scalings in layer {data['layer']} {data['projection']}"
-                    )
-                    continue
+        monkeypatch.setattr(XLoraLayer, "get_maybe_topk_scalings", mock_get_maybe_topk_scalings)
 
-                if hasattr(normalized_scalings, "cpu"):
-                    scalings_np = normalized_scalings.cpu().detach().numpy()
-                else:
-                    scalings_np = normalized_scalings
-
-                # expected shape: (batch, seq_len, num_loras) or (batch, seq_len, top_k)
-                assert scalings_np.ndim == 3, (
-                    f"Unexpected scalings shape {scalings_np.shape} in layer {data['layer']} {data['projection']}"
-                )
-
-                batch_size, seq_len, num_experts = scalings_np.shape
-                for b in range(batch_size):
-                    for t in range(seq_len):
-                        weights = scalings_np[b, t, :]
-                        weight_sum = weights.sum()
-                        assert np.isclose(weight_sum, 1.0, atol=1e-5), (
-                            f"Per-token scaling not normalized in layer {data['layer']} {data['projection']}, "
-                            f"batch={b}, token={t}: sum={weight_sum:.6f}, weights={weights}"
-                        )
-
-        finally:
-            for mp in monkeypatches:
-                mp.undo()
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=1,
+        )
+
+        for scaling in captured_data:
+            weight_sums = scaling.sum(dim=-1)
+            assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
+                "Per-token scaling weights are not normalized to sum to 1."
+            )