
Bug in prefix tuning of Qwen3-0.6B: grouped-query attention fix from #1901 still causes an error #2881

@papaya0481

Description

System Info

Operating system: Linux

Distributor ID: Ubuntu
Description:    Ubuntu 22.04.5 LTS
Release:        22.04
Codename:       jammy

nvcc version:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

Packages:

peft==0.17.1
transformers==4.57.1

Who can help?

No response

Reproduction

When I try prefix tuning of Qwen3-0.6B with code like this:

from peft import PrefixTuningConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM
import torch

path = "Qwen/Qwen3-0.6B"
base = AutoModelForCausalLM.from_pretrained(path)

print("load base")

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type=TaskType.CAUSAL_LM)

print(peft_config)
model = get_peft_model(base, peft_config)
print("load peft model")

# any dummy input ids are enough to trigger the error
x = torch.tensor([[1, 2, 3]])
model(x)

it raises the following error:

load base
PrefixTuningConfig(task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, inference_mode=False, num_virtual_tokens=10, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, modules_to_save=None, encoder_hidden_size=None, prefix_projection=False)
load peft model
Traceback (most recent call last):
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/pdb.py", line 1723, in main
    pdb._runscript(mainpyfile)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/pdb.py", line 1583, in _runscript
    self.run(statement)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/bdb.py", line 598, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/data/wuli_error/Awesome-EmoTTS/new_path.py", line 17, in <module>
    model(x)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/peft/peft_model.py", line 1891, in forward
    return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/generic.py", line 918, in wrapper
    output = func(self, *args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 480, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 410, in forward
    hidden_states = decoder_layer(
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 260, in forward
    hidden_states, _ = self.self_attn(
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 210, in forward
    key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/cache_utils.py", line 776, in update
    keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
  File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/cache_utils.py", line 119, in update
    self.keys = torch.cat([self.keys, key_states], dim=-2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 64 but got size 128 for tensor number 1 in the list.
Uncaught exception. Entering post mortem debugging
Running 'cont' or 'step' will restart the program

It can be noticed that the GQA fix sets peft_config.token_dim to 512, while the key/value states actually need num_key_value_heads × head_dim = 8 × 128 = 1024 (which also equals the model's hidden size); a short sketch of the arithmetic follows the pdb output below.

-> self.keys = torch.cat([self.keys, key_states], dim=-2)
(Pdb) key_states.shape
torch.Size([1, 8, 3, 128])
(Pdb) self.keys.shape
torch.Size([1, 8, 10, 64])
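For clarity, here is a minimal sketch of the dimension arithmetic. The values are taken from the Qwen3-0.6B config (hidden_size=1024, num_attention_heads=16, num_key_value_heads=8, head_dim=128), where head_dim is set explicitly and is not hidden_size // num_attention_heads:

# Sketch of the dimension arithmetic, assuming Qwen3-0.6B config values:
hidden_size = 1024            # default token_dim
num_attention_heads = 16
num_key_value_heads = 8
head_dim = 128                # explicit in the config; != 1024 // 16 == 64

# What the #1901 fix computes:
token_dim = hidden_size // num_attention_heads * num_key_value_heads  # 512
prefix_head_dim = token_dim // num_key_value_heads                    # 64

# What the cache needs to concatenate with key_states of shape [1, 8, 3, 128]:
expected_token_dim = num_key_value_heads * head_dim                   # 1024

assert prefix_head_dim != head_dim  # 64 != 128 -> the torch.cat failure above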

I'm not sure, but according to #1901, this problem should already have been fixed by this logic in PEFT:

    # For grouped-query attention, see #1901.
    if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config:
        num_key_value_heads = model_config["num_key_value_heads"]
        # This implicitly assumes head_dim == token_dim // num_attention_heads,
        # which does not hold for Qwen3 (head_dim=128, but 1024 // 16 == 64).
        peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads
        peft_config.num_attention_heads = num_key_value_heads
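A possible adjustment (just a sketch, not the actual PEFT code; it assumes model_config is the plain config dict used above and that it exposes Qwen3's explicit head_dim key) would be to derive token_dim from head_dim when it is present:

# Hypothetical fix sketch: prefer the explicit head_dim over the value
# derived from hidden_size, when the model config provides one.
if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config:
    num_key_value_heads = model_config["num_key_value_heads"]
    # fall back to the old derivation when head_dim is absent from the config
    head_dim = model_config.get(
        "head_dim", peft_config.token_dim // peft_config.num_attention_heads
    )
    peft_config.token_dim = head_dim * num_key_value_heads  # 128 * 8 = 1024 here
    peft_config.num_attention_heads = num_key_value_heads

With that, token_dim would come out as 1024 for Qwen3-0.6B and the prefix cache shape would match key_states.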

Expected behavior

The dimension mismatch error should not be raised; the prefix key/value cache should match the model's actual head_dim.
