6 changes: 4 additions & 2 deletions src/peft/tuners/xlora/layer.py
@@ -73,10 +73,12 @@ def get_maybe_topk_scalings(self, scalings) -> torch.Tensor:
 
         xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype)
 
+        # Apply per-token normalization to the xLoRA scaling factors using a softmax
         if self.config.enable_softmax_topk:
             nonzero_mask = xlora_scalings != 0
-            softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
-            xlora_scalings[nonzero_mask] = softmax_res_nonzero
+            full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
+            new_scalings = torch.softmax(full, dim=-1)
+            xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)
 
         return xlora_scalings

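Why this change matters: the old code softmaxed `xlora_scalings[nonzero_mask]`, and boolean indexing flattens the surviving entries into a single 1-D tensor, so the softmax normalized across every token (and batch element) at once instead of per token. A minimal standalone sketch with made-up values (a hypothetical 2-token, 4-adapter tensor, not the library's real shapes) contrasting the two paths:

```python
import torch

# Hypothetical scalings after top-2 selection: two entries per token are zeroed.
xlora_scalings = torch.tensor(
    [[0.5, 0.0, 1.5, 0.0],
     [0.0, 2.0, 0.0, 1.0]]
)
nonzero_mask = xlora_scalings != 0

# Old path: boolean indexing flattens the surviving entries to 1-D, so one
# softmax runs over all four values at once.
old = xlora_scalings.clone()
old[nonzero_mask] = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
print(old.sum(dim=-1))  # per-token sums are not 1

# New path: filling masked-out entries with -inf keeps the per-token shape, so
# softmax over dim=-1 normalizes each token's surviving weights independently.
full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
new = torch.softmax(full, dim=-1).masked_fill(~nonzero_mask, 0.0)
print(new.sum(dim=-1))  # tensor([1., 1.])
```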
2 changes: 2 additions & 0 deletions src/peft/tuners/xlora/model.py
@@ -368,6 +368,8 @@ def _pre_forward(module, *args, **kwargs):
             self.lora_model.enable_adapter_layers()
 
             xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
+            # Store computed scalings to fix get_latest_scalings() returning None
+            self.internal_xlora_scalings = xlora_scalings
 
             # =========================== Real forward pass with calculated scalings ==================

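This is the companion fix: `_pre_forward` computed the scalings but never assigned them to `self.internal_xlora_scalings`, which is what `get_latest_scalings()` reads, so it returned None even after a forward pass. A minimal usage sketch, assuming `model` and `tokenizer` are an xLoRA PeftModel and matching tokenizer set up as in the tests below:

```python
# `model` and `tokenizer` are placeholders for the fixtures used in the tests.
inputs = tokenizer.encode("Python is a", return_tensors="pt")
model.generate(input_ids=inputs, max_new_tokens=5)

# Before this patch: None. After: the scalings tensor from the last forward pass.
print(model.get_latest_scalings())
```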
43 changes: 43 additions & 0 deletions tests/test_xlora.py
@@ -23,6 +23,7 @@
 
 from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model
 from peft.peft_model import PeftModel
+from peft.tuners.xlora.layer import XLoraLayer
 from peft.utils import infer_device
 
 
@@ -381,3 +382,45 @@ def test_xlora_loading_valid(self):
         w1 = sd["base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight"]
 
         assert torch.allclose(w0, w1)
+
+    def test_scalings_storage(self, tokenizer, model):
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=10,
+        )
+
+        latest_scalings = model.get_latest_scalings()
+        assert latest_scalings is not None, "get_latest_scalings() should not return None after generation"
+        assert isinstance(latest_scalings, torch.Tensor)
+        assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"
+
+    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model, monkeypatch):
+        model.internal_xlora_classifier.config.top_k_lora = 2
+        model.internal_xlora_classifier.config.enable_softmax = False
+        model.internal_xlora_classifier.config.enable_softmax_topk = True
+
+        captured_data = []
+        orig_get_maybe_topk_scalings = XLoraLayer.get_maybe_topk_scalings
+
+        def mock_get_maybe_topk_scalings(self, scalings):
+            result = orig_get_maybe_topk_scalings(self, scalings)
+            if getattr(model, "internal_xlora_scalings", None) is not None:
+                captured_data.append(result)
+            return result
+
+        monkeypatch.setattr(XLoraLayer, "get_maybe_topk_scalings", mock_get_maybe_topk_scalings)
+
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=1,
+        )
+
+        for scaling in captured_data:
+            weight_sums = scaling.sum(dim=-1)
+            assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
+                "Per-token scaling weights are not normalized to sum to 1."
+            )