Two issues with [X-LoRA] Implementation: Scalings Logging and Top-K Softmax Normalization #2786

@Che-Xu

Description

Hello PEFT team,
I am writing to report two issues I encountered while using X-LoRA with the Qwen2-VL-7B model.

Issue 1: internal_xlora_scalings Not Properly Initialized in _enable_peft_forward_hooks()

Location: _enable_peft_forward_hooks() in src/peft/tuners/xlora/model.py

Problem: After calling enable_scalings_logging(), a subsequent call to get_latest_scalings() returns None, i.e. self.internal_xlora_scalings is still unset.

Suggested Fix:

I suspect the issue is that in the _enable_peft_forward_hooks() function, after the first forward pass computes the xlora_scalings via:

```python
xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
```

the value of xlora_scalings is not assigned to self.internal_xlora_scalings. I suggest adding:

```python
self.internal_xlora_scalings = xlora_scalings
```

immediately after the above line.
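To illustrate the mechanism, here is a minimal sketch using a hypothetical `ToyXLoraModel` class. It is not PEFT's `XLoraModel`: the classifier and method bodies are placeholders. It only shows that a getter returning an attribute stays `None` unless the forward path writes to that attribute, which is what the one-line fix above does.

```python
from typing import Optional


class ToyXLoraModel:
    """Hypothetical stand-in for the X-LoRA model wrapper; not PEFT's actual class."""

    def __init__(self, store_scalings: bool) -> None:
        self.store_scalings = store_scalings
        self.scalings_logging = False
        self.internal_xlora_scalings: Optional[list[float]] = None

    def enable_scalings_logging(self) -> None:
        # Mirrors the repro steps; in this toy the flag does not affect storage.
        self.scalings_logging = True

    def _classifier(self, x: float) -> list[float]:
        # Placeholder for internal_xlora_classifier: fixed per-adapter weights.
        return [0.25 * x, 0.75 * x]

    def forward(self, x: float) -> None:
        xlora_scalings = self._classifier(x)
        if self.store_scalings:
            # The suggested fix: keep the latest scalings on the model.
            self.internal_xlora_scalings = xlora_scalings

    def get_latest_scalings(self) -> Optional[list[float]]:
        return self.internal_xlora_scalings


buggy = ToyXLoraModel(store_scalings=False)
buggy.enable_scalings_logging()
buggy.forward(1.0)
print(buggy.get_latest_scalings())  # None

fixed = ToyXLoraModel(store_scalings=True)
fixed.enable_scalings_logging()
fixed.forward(1.0)
print(fixed.get_latest_scalings())  # [0.25, 0.75]
```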

Issue 2: Incorrect Probability Normalization in Token-Level Top-K Selection

Location: get_maybe_topk_scalings() in src/peft/tuners/xlora/layer.py

Problem: When using enable_softmax=False, enable_softmax_topk=True, and top_k_lora=1, the current implementation normalizes the expert probabilities so that they sum to 1 over all tokens, rather than to 1 per token. This happens because boolean-mask indexing (xlora_scalings[nonzero_mask]) flattens the selected entries into a single 1-D tensor, so softmax(dim=-1) runs over every selected entry in the batch at once. Here is the current code:

```python
if self.config.enable_softmax_topk:
    nonzero_mask = xlora_scalings != 0
    softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
    xlora_scalings[nonzero_mask] = softmax_res_nonzero
```
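The following standalone sketch reproduces the reported behavior on a toy tensor. It assumes per-layer scalings of shape [batch, seq_len, n_classes] and assumes the preceding top-k step zeroes out non-selected adapters by multiplying with a 0/1 mask; the values are made up for illustration.

```python
import torch

# Toy per-layer scalings: batch=1, seq_len=3 tokens, n_classes=2 adapters (made-up values).
xlora_scalings = torch.tensor([[[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]]])
top_k_lora = 1

# Assumed top-k step: keep each token's top-k adapter weights, zero the rest.
_, topk_indices = torch.topk(xlora_scalings, k=top_k_lora, dim=-1)
mask = torch.zeros_like(xlora_scalings).scatter_(-1, topk_indices, 1.0)
xlora_scalings = xlora_scalings * mask

# Current normalization: boolean indexing flattens the selected entries
# into one 1-D tensor, so softmax(dim=-1) spans every token in the batch.
nonzero_mask = xlora_scalings != 0
xlora_scalings[nonzero_mask] = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)

per_token_sums = xlora_scalings.sum(dim=-1)
print(per_token_sums)        # each entry is below 1
print(per_token_sums.sum())  # about 1.0: normalized over all tokens jointly
```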

Suggested Fix:

I believe the intended behavior is that each token's top-k expert probabilities should sum to 1. A more appropriate implementation may be:

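```python
if self.config.enable_softmax_topk:
    # Assumes top_k_lora is set, so topk_indices was computed from xlora_scalings above.
    full_scalings = torch.full_like(xlora_scalings, float('-inf'))
    # Place each token's top-k values; other adapters stay at -inf (weight 0 after softmax).
    full_scalings.scatter_(-1, topk_indices, xlora_scalings.gather(-1, topk_indices))
    # Softmax over the adapter dimension, so each token's weights sum to 1.
    xlora_scalings = torch.softmax(full_scalings, dim=-1)
```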

After applying this change in my experiments, the behavior matched expectations.
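The proposed per-token normalization can be checked in isolation with the sketch below. It wraps the idea in a hypothetical helper, `per_token_topk_softmax`, which computes its own `topk_indices` from a tensor of assumed shape [batch, seq_len, n_classes]; it is not PEFT code.

```python
import torch


def per_token_topk_softmax(xlora_scalings: torch.Tensor, top_k: int) -> torch.Tensor:
    """Softmax over each token's top-k adapter scalings; other adapters get weight 0.

    Hypothetical helper. xlora_scalings: [batch, seq_len, n_classes] (assumed shape).
    """
    _, topk_indices = torch.topk(xlora_scalings, k=top_k, dim=-1)
    # Start from -inf so non-selected adapters become exp(-inf) = 0 after softmax.
    full_scalings = torch.full_like(xlora_scalings, float("-inf"))
    full_scalings.scatter_(-1, topk_indices, xlora_scalings.gather(-1, topk_indices))
    # Normalize along the adapter dimension only, i.e. independently per token.
    return torch.softmax(full_scalings, dim=-1)


scalings = torch.tensor([[[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]]])
top1 = per_token_topk_softmax(scalings, top_k=1)
print(top1)              # one-hot per token: weight 1.0 on the selected adapter
print(top1.sum(dim=-1))  # every token sums to 1
```

Note that with top_k_lora=1 the per-token softmax is one-hot, so the selected adapter always receives weight 1.0 regardless of the classifier's raw score.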

Could you please confirm whether these are indeed issues and whether the proposed changes are appropriate? Thank you for your time and effort.