System Info
peft 0.13.0
transformers 4.44.2
torch 2.4.0
Python 3.12.4
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder
- My own task or dataset (give details below)
Reproduction
```python
import os

os.environ["WANDB_DISABLED"] = "true"

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from trl import SFTTrainer

model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# GPT-2's attn.c_attn is a transformers Conv1D, so fan_in_fan_out=True is required
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0,
    target_modules=["attn.c_attn"],
    init_lora_weights="pissa",
    fan_in_fan_out=True,
    bias="none",
)
model = get_peft_model(model, lora_config)

dataset = load_dataset("imdb", split="train[:1%]")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)
trainer.train()  # fails with a dimension mismatch because of the PiSSA init bug below
```
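For context, GPT-2's `attn.c_attn` is a transformers `Conv1D` module, which stores its weight as `(in_features, out_features)` rather than `nn.Linear`'s `(out_features, in_features)`; that is why the config above sets `fan_in_fan_out=True`. A minimal check of the layer's layout, independent of the trainer:

```python
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
c_attn = model.transformer.h[0].attn.c_attn
print(type(c_attn).__name__)  # Conv1D (transformers' module, not torch.nn.Conv1d)
print(c_attn.weight.shape)    # torch.Size([768, 2304]) -> (in_features, out_features)
```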
Expected behavior
Hello, I found that the current PiSSA init code forgets to take the `fan_in_fan_out` parameter into account and transpose the weight matrix, which makes GPT-2 training fail with a dimension mismatch. I have fixed the bug with the following code:
```python
def pissa_init(self, adapter_name, init_lora_weights):
    weight = self.get_base_layer().weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    # Fix: transpose into (out_features, in_features) orientation before the SVD
    weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}
        V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
        Vr = V[:, : self.r[adapter_name]]
        Sr = S[: self.r[adapter_name]]
        Sr /= self.scaling[adapter_name]
        Uhr = Uh[: self.r[adapter_name]]
    elif len(init_lora_weights.split("_niter_")) == 2:
        Vr, Sr, Ur = svd_lowrank(
            weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
        )
        Sr /= self.scaling[adapter_name]
        Uhr = Ur.t()
    else:
        raise ValueError(
            f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
        )
    lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
    lora_B = Vr @ torch.diag(torch.sqrt(Sr))
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
    # Fix: transpose the residual back to the base layer's original orientation
    weight = transpose(weight.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight
```
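To see why the transpose matters: without it, the SVD runs on the Conv1D weight in `(in_features, out_features)` orientation, so the factors come out with swapped shapes relative to what the adapter modules expect. A minimal sketch using GPT-2's `c_attn` dimensions (768 in, 2304 out) and `r=8`:

```python
import torch

# Conv1D weight as stored by GPT-2: (in_features, out_features)
W = torch.randn(768, 2304)
r = 8

# What the current (buggy) code computes, i.e. SVD without transposing first:
V, S, Uh = torch.linalg.svd(W, full_matrices=False)
lora_A = torch.diag(torch.sqrt(S[:r])) @ Uh[:r]    # shape (8, 2304)
lora_B = V[:, :r] @ torch.diag(torch.sqrt(S[:r]))  # shape (768, 8)

# PEFT allocates lora_A as nn.Linear(768, 8) -> weight (8, 768), and
# lora_B as nn.Linear(8, 2304) -> weight (2304, 8). Assigning the tensors
# above via .data bypasses any shape check, so the error only surfaces as a
# dimension mismatch in the forward pass during training.
print(lora_A.shape, lora_B.shape)  # torch.Size([8, 2304]) torch.Size([768, 8])
```

Transposing on entry and back on exit keeps the SVD in the `(out_features, in_features)` orientation the rest of `pissa_init` assumes, while leaving the stored Conv1D weight in its original layout.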