System Info
Operating system: Linux
Distributor ID: Ubuntu
Description: Ubuntu 22.04.5 LTS
Release: 22.04
Codename: jammy
nvcc version:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
Packages:
peft==0.17.1
transformers==4.57.1
Who can help?
No response
Reproduction
When I try prefix tuning on Qwen3-0.6B with code like this:
from peft import PrefixTuningConfig, get_peft_model, TaskType, PeftType
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
path = "Qwen/Qwen3-0.6B"
base = AutoModelForCausalLM.from_pretrained(path)
print("load base")
peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type=TaskType.CAUSAL_LM)
print(peft_config)
model = get_peft_model(base, peft_config)
print("load peft model")
x = torch.tensor([[1, 2, 3]])
model(x)
Running this produces the following output and traceback:
load base
PrefixTuningConfig(task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, inference_mode=False, num_virtual_tokens=10, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, modules_to_save=None, encoder_hidden_size=None, prefix_projection=False)
load peft model
Traceback (most recent call last):
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/pdb.py", line 1723, in main
pdb._runscript(mainpyfile)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/pdb.py", line 1583, in _runscript
self.run(statement)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/bdb.py", line 598, in run
exec(cmd, globals, locals)
File "<string>", line 1, in <module>
File "/data/wuli_error/Awesome-EmoTTS/new_path.py", line 17, in <module>
model(x)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/peft/peft_model.py", line 1891, in forward
return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/generic.py", line 918, in wrapper
output = func(self, *args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 480, in forward
outputs: BaseModelOutputWithPast = self.model(
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 410, in forward
hidden_states = decoder_layer(
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 260, in forward
hidden_states, _ = self.self_attn(
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 210, in forward
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/cache_utils.py", line 776, in update
keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
File "/data/wuli_error/miniconda3/envs/llm-pbe/lib/python3.10/site-packages/transformers/cache_utils.py", line 119, in update
self.keys = torch.cat([self.keys, key_states], dim=-2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 64 but got size 128 for tensor number 1 in the list.
Uncaught exception. Entering post mortem debugging
Running 'cont' or 'step' will restart the program
Notice that get_peft_model sets peft_config.token_dim to 512, which does not match the key/value hidden size the model actually uses, 1024 (8 KV heads × head_dim 128, as the shapes below show).
-> self.keys = torch.cat([self.keys, key_states], dim=-2)
(Pdb) key_states.shape
torch.Size([1, 8, 3, 128])
(Pdb) self.keys.shape
torch.Size([1, 8, 10, 64])
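For reference, here is a small sketch (it only inspects the published model config, not PEFT internals) showing where the mismatch comes from: Qwen3 sets head_dim explicitly, and it is not equal to hidden_size // num_attention_heads.

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(cfg.hidden_size)           # 1024
print(cfg.num_attention_heads)   # 16
print(cfg.num_key_value_heads)   # 8
print(cfg.head_dim)              # 128 -- explicit, NOT hidden_size // num_attention_heads == 64

# Per-layer key/value size the attention actually expects:
print(cfg.num_key_value_heads * cfg.head_dim)  # 8 * 128 = 1024
# token_dim that PEFT's GQA adjustment ends up with (see the snippet below):
print(cfg.hidden_size // cfg.num_attention_heads * cfg.num_key_value_heads)  # 512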
I'm not sure, but according to #1901 this problem should already have been fixed by this code:
# For grouped-query attention, see #1901.
if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config:
    num_key_value_heads = model_config["num_key_value_heads"]
    peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads
    peft_config.num_attention_heads = num_key_value_heads
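Plugging in the Qwen3-0.6B numbers (my reading of the snippet above, assuming token_dim starts out as hidden_size) shows why this adjustment produces the wrong size:

hidden_size = 1024          # model_config["hidden_size"]
num_attention_heads = 16    # model_config["num_attention_heads"]
num_key_value_heads = 8     # model_config["num_key_value_heads"]
head_dim = 128              # explicit in Qwen3's config

# The adjustment implicitly assumes head_dim == hidden_size // num_attention_heads (64 here):
token_dim = hidden_size // num_attention_heads * num_key_value_heads  # 512
print(token_dim // num_key_value_heads)  # 64 -> the last dim of self.keys in the pdb session

# What the model actually needs per layer:
print(num_key_value_heads * head_dim)    # 1024 (per-head dim 128)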
Expected behavior
The dimension mismatch error should not be raised; prefix tuning should work with models whose head_dim differs from hidden_size // num_attention_heads.
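Until this is fixed, one possible workaround could be to pre-set token_dim and num_attention_heads so that the quoted adjustment lands on the sizes Qwen3 actually uses. This is an untested sketch; it assumes get_peft_model only fills these fields in when they are None and still applies the quoted adjustment to user-supplied values:

from peft import PrefixTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
peft_config = PrefixTuningConfig(
    num_virtual_tokens=10,
    task_type=TaskType.CAUSAL_LM,
    token_dim=16 * 128,      # num_attention_heads * head_dim = 2048 instead of hidden_size
    num_attention_heads=16,  # the GQA adjustment then yields 2048 // 16 * 8 = 1024, 8 heads
)
model = get_peft_model(base, peft_config)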