diff --git a/src/peft/tuners/prompt_tuning/config.py b/src/peft/tuners/prompt_tuning/config.py
index ee8ceb8d89..b41669efe8 100644
--- a/src/peft/tuners/prompt_tuning/config.py
+++ b/src/peft/tuners/prompt_tuning/config.py
@@ -22,6 +22,7 @@
 
 class PromptTuningInit(str, enum.Enum):
     TEXT = "TEXT"
+    SAMPLE_VOCAB = "SAMPLE_VOCAB"
     RANDOM = "RANDOM"
 
 
@@ -31,7 +32,10 @@ class PromptTuningConfig(PromptLearningConfig):
     This is the configuration class to store the configuration of a [`PromptEmbedding`].
 
     Args:
-        prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding.
+        prompt_tuning_init (Union[[`PromptTuningInit`], `str`]):
+            The initialization of the prompt embedding. `TEXT` will initialize with your text. `SAMPLE_VOCAB` will
+            initialize with randomly sampled tokens from the model's vocabulary. `RANDOM` will initialize with randomly
+            sampled continuous, soft tokens (warning: sampled soft tokens may fall outside of the embedding manifold).
         prompt_tuning_init_text (`str`, *optional*):
             The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
         tokenizer_name_or_path (`str`, *optional*):
diff --git a/src/peft/tuners/prompt_tuning/model.py b/src/peft/tuners/prompt_tuning/model.py
index ce9b6bc409..9852ea28b4 100644
--- a/src/peft/tuners/prompt_tuning/model.py
+++ b/src/peft/tuners/prompt_tuning/model.py
@@ -64,7 +64,18 @@ def __init__(self, config, word_embeddings):
 
         total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
         self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
-        if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
+        if config.prompt_tuning_init == PromptTuningInit.SAMPLE_VOCAB and not config.inference_mode:
+            # Randomly sample tokens from the tokenizer's vocab
+            vocab_size = word_embeddings.num_embeddings
+            init_token_ids = torch.randint(0, vocab_size, (total_virtual_tokens,), dtype=torch.long).to(
+                word_embeddings.weight.device
+            )
+            with gather_params_ctx(word_embeddings.parameters()):
+                word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
+            word_embedding_weights = word_embedding_weights.to(torch.float32)
+            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
+
+        elif config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
             from transformers import AutoTokenizer
 
             tokenizer_kwargs = config.tokenizer_kwargs or {}
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 5c01e29052..12b4a62c3b 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -380,6 +380,18 @@ def mock_autotokenizer_from_pretrained(*args, **kwargs):
         expected_call = call(model_id, trust_remote_code=True, foo="bar")
         assert mock.call_args == expected_call
 
+    @pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
+    @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
+    def test_prompt_tuning_sample_vocab_prepare_for_training(self, model_id, config_cls, config_kwargs):
+        if config_cls != PromptTuningConfig:
+            pytest.skip(f"This test does not apply to {config_cls}")
+
+        config_kwargs = config_kwargs.copy()
+        config_kwargs["prompt_tuning_init"] = PromptTuningInit.SAMPLE_VOCAB
+        config_kwargs["tokenizer_name_or_path"] = model_id
+
+        self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())
+
     def test_prompt_tuning_config_invalid_args(self):
         # Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no
        # function in that case
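
For review context, here is a minimal usage sketch of the new `SAMPLE_VOCAB` option; it is not part of the diff above. The base model id (`"gpt2"`) and `num_virtual_tokens=20` are illustrative placeholders, and the snippet assumes a peft build that includes this change.

```python
# Usage sketch only -- model id and hyperparameters are illustrative, not from this PR.
from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model

# SAMPLE_VOCAB initializes the virtual tokens from randomly sampled rows of the
# model's input embedding matrix (see the model.py hunk above), so no
# prompt_tuning_init_text is needed.
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.SAMPLE_VOCAB,
    num_virtual_tokens=20,
)

model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
```

Compared to `RANDOM`, this keeps the initial soft prompt on actual embedding vectors; compared to `TEXT`, it does not require choosing an init string or loading a tokenizer.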