diff --git a/src/peft/mapping_func.py b/src/peft/mapping_func.py
index adcb55a8e5..fae671819e 100644
--- a/src/peft/mapping_func.py
+++ b/src/peft/mapping_func.py
@@ -25,7 +25,6 @@
 from .mixed_model import PeftMixedModel
 from .peft_model import PeftModel
 from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
-from .utils import _prepare_prompt_learning_config
 
 
 def get_peft_model(
@@ -120,8 +119,6 @@ def get_peft_model(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
-    if peft_config.is_prompt_learning:
-        peft_config = _prepare_prompt_learning_config(peft_config, model_config)
     return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
         model,
         peft_config,
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 07985e9b6f..b24858ebb9 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -1105,9 +1105,13 @@ def _prepare_prompt_learning_config(peft_config, model_config):
         peft_config.num_attention_heads = num_attention_heads
 
     # For grouped-query attention, see #1901.
-    if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config:
+    if (peft_config.peft_type == "PREFIX_TUNING") and ("num_key_value_heads" in model_config):
         num_key_value_heads = model_config["num_key_value_heads"]
-        peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads
+        if model_config.get("head_dim", None) is not None:
+            head_dim = model_config["head_dim"]
+        else:
+            head_dim = peft_config.token_dim // peft_config.num_attention_heads
+        peft_config.token_dim = head_dim * num_key_value_heads
         peft_config.num_attention_heads = num_key_value_heads
 
     if getattr(peft_config, "encoder_hidden_size", None) is None:
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 72925e2d18..cf5e882eca 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -665,15 +665,27 @@ def test_lora_layer_replication(self):
         self._test_prepare_for_training(model_id, LoraConfig, config_kwargs.copy())
         self._test_generate(model_id, LoraConfig, config_kwargs.copy())
 
-    def test_prompt_learning_with_grouped_query_attention(self):
+    def test_prefix_tuning_qwen2_with_grouped_query_attention(self):
         # See 1901, fixes a bug with handling GQA
         model_id = "peft-internal-testing/tiny-dummy-qwen2"
-        base_model = AutoModelForCausalLM.from_pretrained(model_id)
-        peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
-        model = get_peft_model(base_model, peft_config)
-        x = torch.tensor([[1, 2, 3]])
-        # does not raise
-        model(x)
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
+            model = get_peft_model(base_model, peft_config)
+            x = torch.tensor([[1, 2, 3]])
+            # does not raise
+            model(x)
+
+    def test_prefix_tuning_qwen3_with_grouped_query_attention(self):
+        # See 2881, fixes a bug with handling GQA
+        model_id = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
+            model = get_peft_model(base_model, peft_config)
+            x = torch.tensor([[1, 2, 3]])
+            # does not raise
+            model(x)
 
     def test_prefix_tuning_mistral(self):
         # See issue 869, 1962
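Note (illustrative, not part of the patch): the old token_dim computation assumed head_dim == hidden_size // num_attention_heads, which does not hold for models such as Qwen3 that set an explicit head_dim in their config (see #2881). A minimal sketch with hypothetical config values showing the difference:

    # Hypothetical GQA config values, loosely modeled on a Qwen3-style model where
    # head_dim is set explicitly and differs from hidden_size // num_attention_heads.
    hidden_size = 1024           # peft_config.token_dim before adjustment
    num_attention_heads = 16
    num_key_value_heads = 8
    head_dim = 128               # explicit "head_dim" entry in the model config

    old_token_dim = hidden_size // num_attention_heads * num_key_value_heads  # 64 * 8 = 512, too small
    new_token_dim = head_dim * num_key_value_heads                            # 128 * 8 = 1024, matches the KV projections

The old formula produces prefix key/value shapes that do not match the model's key/value projections whenever head_dim is set explicitly; the patch prefers head_dim from the model config and falls back to the derived value otherwise.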