
LoRA PiSSA init: does not support gpt2 #2103

@suyang160

Description


System Info

peft 0.13.0
transformers 4.44.2
torch 2.4.0
Python 3.12.4

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import os
os.environ["WANDB_DISABLED"] = "true"
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset
from torchsummary import summary
import torch
from datasets import load_dataset, config
from trl import SFTTrainer

model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0,
    target_modules=["attn.c_attn"],
    init_lora_weights="pissa",
    fan_in_fan_out=True,  # GPT-2's attn.c_attn is a Conv1D layer, stored as (in_features, out_features)
    bias="none",
)

model = get_peft_model(model, lora_config)

dataset = load_dataset("imdb", split="train[:1%]")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)

Expected behavior

Hello, I found that the current PiSSA init code forgets to take the fan_in_fan_out parameter into account when transposing the weight matrix, which makes GPT-2 training fail with a dimension mismatch. I have fixed the bug with the code shown after the short check below.
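For context, here is a minimal check (a sketch added for illustration, not part of the original report) showing why the transpose matters: GPT-2's attn.c_attn is a Conv1D module whose weight is stored as (in_features, out_features), the opposite of nn.Linear, which is exactly what fan_in_fan_out=True signals to PEFT.

from transformers import GPT2LMHeadModel

gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
c_attn = gpt2.transformer.h[0].attn.c_attn
print(type(c_attn).__name__)  # Conv1D
print(c_attn.weight.shape)    # torch.Size([768, 2304]) -> (in_features, out_features)
# An nn.Linear computing the same projection would store (2304, 768), so running
# the SVD on the untransposed Conv1D weight yields lora_A / lora_B with swapped
# shapes, and the forward pass then fails with a dimension mismatch.

The fixed pissa_init: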

def pissa_init(self, adapter_name, init_lora_weights):
    weight = self.get_base_layer().weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    # Changed: respect fan_in_fan_out so the SVD always sees the (out_features, in_features) layout.
    weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
        V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
        Vr = V[:, : self.r[adapter_name]]
        Sr = S[: self.r[adapter_name]]
        Sr /= self.scaling[adapter_name]
        Uhr = Uh[: self.r[adapter_name]]
    elif len(init_lora_weights.split("_niter_")) == 2:
        Vr, Sr, Ur = svd_lowrank(
            weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
        )
        Sr /= self.scaling[adapter_name]
        Uhr = Ur.t()
    else:
        raise ValueError(
            f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
        )

    lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
    lora_B = Vr @ torch.diag(torch.sqrt(Sr))
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
    # Changed: transpose back to the base layer's original storage layout before writing the residual.
    weight = transpose(weight.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight
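
As a quick sanity check (a sketch only, assuming the patched pissa_init above is applied and the default adapter name "default" is used), the adapter weights produced for GPT-2's c_attn then have the expected shapes:

# Reusing `model` from the reproduction script above (after get_peft_model):
layer = model.base_model.model.transformer.h[0].attn.c_attn  # LoRA wrapper around the Conv1D
print(layer.lora_A["default"].weight.shape)  # torch.Size([8, 768])   -> (r, in_features)
print(layer.lora_B["default"].weight.shape)  # torch.Size([2304, 8])  -> (out_features, r)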
