
Conversation

aflueckiger (Contributor) commented Oct 23, 2025

It looks like trainable_token_indices has been broken since #2605 for the lm_head when the weights are not tied. The lm_head is an instance of Linear rather than Embedding and thus has no embedding_dim attribute.
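To illustrate the mismatch with plain PyTorch (a minimal sketch, not PEFT code; the layer sizes are made up):

import torch.nn as nn

# An untied causal LM pairs an input embedding with a separate output projection:
embed_tokens = nn.Embedding(num_embeddings=32000, embedding_dim=4096)
lm_head = nn.Linear(in_features=4096, out_features=32000, bias=False)

print(embed_tokens.embedding_dim)         # 4096
print(hasattr(lm_head, "embedding_dim"))  # False
lm_head.embedding_dim                     # raises AttributeError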

The embedding dimension is not even needed further down except when training with DeepSpeed or re-initializing the weights.

Asking @BenjaminBossan for a review due to the changes in #2605. I couldn't find any existing reports of this issue. Let me know if you would like to fix it differently.

Minimal script for reproduction:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model_name = "utter-project/EuroLLM-9B-Instruct"

trainable_tokens_indices = [5, 6, 7, 8, 9]

# Target both the LM head and the embedding layer with the same indices.
trainable_tokens = {"lm_head": trainable_tokens_indices, "embed_tokens": trainable_tokens_indices}

# Before this PR, only targeting the embedding layer worked:
# trainable_tokens = {"embed_tokens": trainable_tokens_indices}
# trainable_tokens = trainable_tokens_indices

lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    modules_to_save=None,
    trainable_token_indices=trainable_tokens,
)

model = AutoModelForCausalLM.from_pretrained(model_name)
# Without the fix, this raises:
# AttributeError: 'Linear' object has no attribute 'embedding_dim'
get_peft_model(model, lora_config)

BenjaminBossan (Member) left a comment

Thanks for providing this fix.

For the purpose of testing, could you please give a small example where this is broken? The model ID and PEFT config should be enough.

> The embedding dimension is not even needed further down except when training with DeepSpeed or re-initializing the weights.

It's also used when init_weights=False. I guess we could move it into the corresponding branches to avoid retrieving this attribute when it's not needed, but I think it's better to make the code robust enough that it does not break.
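For illustration, here is a minimal sketch of one robust way to retrieve the dimension (a hypothetical helper, not necessarily the code merged in this PR). It relies on the fact that nn.Embedding stores its weight as (num_embeddings, embedding_dim) and nn.Linear as (out_features, in_features), so the last weight axis is the embedding dimension in both cases:

import torch.nn as nn

def _get_embedding_dim(module: nn.Module) -> int:
    # Hypothetical helper, for illustration only.
    if isinstance(module, nn.Embedding):
        return module.embedding_dim  # exposed directly
    if isinstance(module, nn.Linear):
        return module.in_features    # hidden size of an untied lm_head
    # Fallback: for both layer types above, the weight's last axis
    # is the embedding dimension.
    return module.weight.shape[-1]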

aflueckiger (Contributor, Author) commented

Thanks for the quick response. I added more information to reproduce the error to the PR description and replaced the try-except block. Let me know about the preferred place for handling the attribute.

BenjaminBossan (Member) commented

Thanks for providing the reproducer. I can confirm that it raises the error, and it helped me narrow down why the existing tests didn't catch this. We have a gap where we don't test specifying both embed_tokens and lm_head. To fill the gap, could you please add a unit test? Here is a suggestion:

    def test_trainable_token_indices_targets_head_and_embedding(self):
        # targeting embedding and LM head explicitly, see #2863
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        with hub_online_once(model_id):
            model = AutoModelForCausalLM.from_pretrained(model_id)
            config = LoraConfig(trainable_token_indices={"lm_head": [0], "embed_tokens": [0]})
            get_peft_model(model, config)  # does not raise

A good spot for this test would be here:

aflueckiger (Contributor, Author) commented

Thanks for pointing out the right spot for the test. It's added now.

HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs remain available for 30 days after the last update.

BenjaminBossan (Member) left a comment

Thanks for fixing this issue, LGTM.

@BenjaminBossan BenjaminBossan merged commit fff52ab into huggingface:main Oct 23, 2025
12 of 13 checks passed
