
Conversation

@macmacmacmac (Contributor) commented Oct 7, 2025

TLDR: Implemented a new initialization option for PromptEmbedding that fixes an issue where the original RANDOM initialization produces embeddings that fall outside of the model's vocabulary manifold.

I implemented a new initialization option called RANDOM_DISCRETE. Usage is the same as RANDOM.

Usage

generation_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM_DISCRETE,  # samples random tokens from the vocab
    num_virtual_tokens=NUM_VIRTUAL_TOKENS,
    tokenizer_name_or_path=model_name,
)

Currently, PromptEmbedding's random initialization uses torch.nn.Embedding to initialize its embeddings. However, because different models have different embedding spaces, each with its own manifold of well-defined vocab embeddings, this naive initialization is highly unlikely to produce embeddings that land on that manifold, which empirically leads to very poor learning. In my testing, naive random initialization reduces accuracy by almost a factor of 3.
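To make the idea concrete, here is a rough sketch of what sampling-based initialization looks like (illustrative names only, not the actual PEFT code):

import torch

# Sketch only: copy randomly sampled rows of the base model's input embedding matrix,
# so every virtual token starts on the vocabulary manifold.
def init_virtual_tokens(word_embedding_weights: torch.Tensor, num_virtual_tokens: int) -> torch.nn.Embedding:
    vocab_size, token_dim = word_embedding_weights.shape
    token_ids = torch.randint(0, vocab_size, (num_virtual_tokens,))
    init_weights = word_embedding_weights[token_ids].detach().clone().to(torch.float32)
    embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
    embedding.weight = torch.nn.Parameter(init_weights)
    return embedding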

To help visualize, here's a PCA of Llama 3.1 8B's vocab embeddings vs. naively randomly initialized embeddings:

Image

Versus initializing the embeddings by randomly sampling tokens from the vocabulary instead (the suggested fix/feature):

Image

@BenjaminBossan (Member) left a comment

Thanks for coming up with, and implementing, this idea for better initialization of prompt tuning. I have a few small comments, please check.

Moreover, let's add a unit test for this new option. There are some existing tests here:

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_prompt_tuning_text_prepare_for_training(self, model_id, config_cls, config_kwargs):
    if config_cls != PromptTuningConfig:
        pytest.skip(f"This test does not apply to {config_cls}")

    config_kwargs = config_kwargs.copy()
    config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT
    config_kwargs["prompt_tuning_init_text"] = "This is a test prompt."
    config_kwargs["tokenizer_name_or_path"] = model_id
    self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())

def test_prompt_tuning_text_tokenizer_kwargs(self):
    # Allow users to pass additional arguments to Tokenizer.from_pretrained
    # Fix for #1032
    mock = Mock()
    orig_from_pretrained = AutoTokenizer.from_pretrained

    def mock_autotokenizer_from_pretrained(*args, **kwargs):
        mock(*args, **kwargs)
        return orig_from_pretrained(config.tokenizer_name_or_path)

    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
    config = PromptTuningConfig(
        base_model_name_or_path=model_id,
        tokenizer_name_or_path=model_id,
        num_virtual_tokens=10,
        prompt_tuning_init=PromptTuningInit.TEXT,
        task_type="CAUSAL_LM",
        prompt_tuning_init_text="This is a test prompt.",
        tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"},
    )
    model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)

    with patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained):
        _ = get_peft_model(model, config)

    expected_call = call(model_id, trust_remote_code=True, foo="bar")
    assert mock.call_args == expected_call

Could you please create a new one below those? You can use test_prompt_tuning_text_prepare_for_training as a template.
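Something along these lines could work as a starting point (the test name is just a suggestion; it reuses the fixtures and helper from the existing tests):

@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_prompt_tuning_random_discrete_prepare_for_training(self, model_id, config_cls, config_kwargs):
    if config_cls != PromptTuningConfig:
        pytest.skip(f"This test does not apply to {config_cls}")

    config_kwargs = config_kwargs.copy()
    # Only difference vs the TEXT variant: initialize by sampling random tokens from the vocab.
    config_kwargs["prompt_tuning_init"] = PromptTuningInit.RANDOM_DISCRETE
    config_kwargs["tokenizer_name_or_path"] = model_id
    self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())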

Regarding the implementation itself, my understanding is that it will use random embedding vectors from the existing embedding matrix. This looks like a good idea for better initialization, but I'm curious if you have tested it in an experiment to compare the performance to random initialization?

Also, I wonder if, instead of completely copying a vector, we could copy random elements from the embedding size axis, so that the resulting vectors are unique. Maybe it makes no difference, just something that I was thinking about.
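To illustrate, a minimal sketch of that element-wise sampling idea (word_embedding_weights stands in for the vocab embedding matrix; names are illustrative):

import torch

def init_by_elementwise_sampling(word_embedding_weights: torch.Tensor, num_virtual_tokens: int) -> torch.Tensor:
    vocab_size, token_dim = word_embedding_weights.shape
    # For each virtual token and each embedding dimension, take the value from a randomly
    # chosen vocab row, so the resulting vectors are (almost surely) unique while each
    # element still follows the vocabulary's distribution along that dimension.
    row_ids = torch.randint(0, vocab_size, (num_virtual_tokens, token_dim))
    col_ids = torch.arange(token_dim).expand(num_virtual_tokens, token_dim)
    return word_embedding_weights[row_ids, col_ids].detach().clone()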

Quoted diff context:

        self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    elif config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
        import numpy as np

@BenjaminBossan (Member):

Can be removed

Suggested change:

-        import numpy as np

Quoted diff context:

    self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
    if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
    if config.prompt_tuning_init == PromptTuningInit.RANDOM_DISCRETE and not config.inference_mode:
        import numpy as np

@BenjaminBossan (Member):

How about using torch directly, instead of creating numpy arrays and later converting them to torch?
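For example, something like this (a sketch with illustrative names, not the exact diff):

import torch

def sample_vocab_embeddings(word_embeddings: torch.nn.Embedding, total_virtual_tokens: int) -> torch.Tensor:
    # Sample token indices directly with torch instead of np.random.choice + torch.LongTensor.
    indices = torch.randint(0, word_embeddings.num_embeddings, (total_virtual_tokens,))
    return word_embeddings(indices).detach().clone()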

Quoted diff context (docstring):

    Args:
        prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding.
        prompt_tuning_init (Union[[`PromptTuningInit`], `str`]):
            The initialization of the prompt embedding. `TEXT` will initialize with your text. `RANDOM_DISCRETE` will

@BenjaminBossan (Member):

Thanks for extending the docstring. Since RANDOM_DISCRETE is not one of the options described in the paper, let's mention that. Otherwise, users might start searching the paper for that option.

@macmacmacmac (Contributor, Author):

Hi, I believe the paper does actually mention this initialization scheme. I italicized the part that RANDOM_DISCRETE corresponds to; RANDOM corresponds to the first option.

Though...yeah...RANDOM_DISCRETE maybe isn't the most intuitive name. I'm not familiar with naming schemes, so I would appreciate suggestions.

2.1 Design Decisions

There are many possible ways to initialize the prompt representations. The simplest is to train from scratch, using random initialization. *A more sophisticated option is to initialize each prompt token to an embedding drawn from the model's vocabulary.* Conceptually, our soft-prompt modulates the frozen network's behavior in the same way as text preceding the input, so it follows that a word-like representation might serve as a good initialization spot.

@macmacmacmac (Contributor, Author) commented Oct 8, 2025:

The paper actually ran some experiments comparing the different initialization schemes; it seems to perform consistently better than random initialization up to a certain model size (they called this "sampled vocab", so I will change the name to reflect that). I would suggest making this the default initialization method, but I guess that might cause backward compatibility issues.

Image

@macmacmacmac (Contributor, Author):

@BenjaminBossan

Moreover, let's add a unit test for this new option. There are some existing tests here:

I'll work on that! Thank you for the other feedback!

Regarding the implementation itself, my understanding is that it will use random embedding vectors from the existing embedding matrix. This looks like a good idea for better initialization, but I'm curious if you have tested it in an experiment to compare the performance to random initialization?

I did actually test this out. I ran a couple of trials on my task, comparing random, discrete random, a sensible prompt, and an unrelated prompt, using Llama as my base model on a regression task. I was getting as low as 0.3 Pearson R for random initialization, while the other methods (including random discrete) were consistently >0.9 Pearson R.

From what I gathered, random initialization produces vectors that fall outside the part of the embedding space that is semantically well-defined for the model, so the model doesn't know how to interpret these embeddings, which leads to worse performance. You can check this for yourself by comparing the average L2 norm of the random embeddings with that of the model's vocab embeddings. For Llama, I found it was around 64x larger.
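Here's a quick sketch of that check (the model id is just an example; any causal LM works):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
vocab_emb = model.get_input_embeddings().weight.detach().float()

# Same shape, but initialized the default torch.nn.Embedding way (elements drawn from N(0, 1)).
random_emb = torch.nn.Embedding(*vocab_emb.shape).weight.detach()

print("vocab  mean L2 norm:", vocab_emb.norm(dim=-1).mean().item())
print("random mean L2 norm:", random_emb.norm(dim=-1).mean().item())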

@macmacmacmac force-pushed the added-new-init-prompt-embedding branch from 1accee7 to 68c1056 on October 9, 2025 at 00:58
@BenjaminBossan (Member):

Thanks for your feedback. Indeed, your method corresponds to the "sampled vocab" option. I agree that it should be renamed to make this connection clear. Also thanks for reporting your results. Intuitively, I find the idea of taking random embedding vectors a bit strange, but if it works empirically, I won't argue with that ;)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks for this nice addition, the PR LGTM.

@BenjaminBossan BenjaminBossan merged commit 2c29cf7 into huggingface:main Oct 9, 2025
5 of 13 checks passed
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Oct 9, 2025
A new initialization method was added to prompt tuning in huggingface#2815. This PR
adds an experiment config for this method to the MetaMathQA benchmark.

Testing locally, this got a test accuracy of 36%, compared to 25% with
random initialization.
BenjaminBossan added a commit that referenced this pull request Oct 13, 2025
A new initialization method was added to prompt tuning in #2815. This PR
adds an experiment config for this method to the MetaMathQA benchmark.

Testing locally, this got a test accuracy of 36%, compared to 25% with
random initialization.
