Merged
57 commits
66d5b14
shira_implementation
hteague-qti Jun 11, 2025
67c6c7f
Remove an unnecessary file
kkb-code Jun 11, 2025
ca7d7d7
Fix the help text in SHiRA config
kkb-code Jun 11, 2025
7cbc3ec
Add unit tests and minor bug fix
kkb-code Jun 13, 2025
90d4f54
Add SHiRA support into test_custom_models.py
kkb-code Jun 16, 2025
9345d87
Add or correct copyright notices to all new SHiRA files
kkb-code Jun 16, 2025
86276fa
Remove __getattr__ function from SHiRA __init__.py file
kkb-code Jun 16, 2025
327f1ad
Remove merged property from ShiraLayer class because it is already in…
kkb-code Jun 16, 2025
a448963
Clean up some unused variables in ShiraLayer
kkb-code Jun 16, 2025
d849bca
Remove kwargs line in SHiRA mask_functions.py
kkb-code Jun 18, 2025
ed7b2e0
Default random seed set to None for SHiRA random mask
kkb-code Jun 18, 2025
97a8a16
Lint fixes
kkb-code Jun 18, 2025
d93f8b3
Fix ShiraConfig and mask_fn implementation to have Literals and a sep…
kkb-code Jun 24, 2025
cc13e33
Inherit ShiraLayer from nn.Module and a minor typo fix in config
kkb-code Jun 24, 2025
c363452
Clean up dead code (shira_embedding and to_dict and other comments)
kkb-code Jun 24, 2025
0ab4155
Minor bugfix in ShiraLayer
kkb-code Jun 24, 2025
6a46c1d
Fix docstring in SHiRA mask_functions.py
kkb-code Jun 24, 2025
b272653
Add remaining SHiRA tests
kkb-code Jun 25, 2025
64a1d54
Attempt the new forward implementation for SHiRA. Left commented beca…
kkb-code Jun 25, 2025
9ad780f
Lint fixes for SHiRA
kkb-code Jun 25, 2025
0a85520
Add SHiRA example and fix some docstrings in config.
kkb-code Jun 25, 2025
6c3fb5b
Fix readme for SHiRA examples
kkb-code Jun 25, 2025
e48cac5
Add SHiRA docs
kkb-code Jun 26, 2025
5867f83
Minor Lint fixes for SHiRA examples code
kkb-code Jun 26, 2025
8754bf4
Merge branch 'main' into shira_adapters
kkb-code Jun 26, 2025
eac45d5
Add init_randn_shira_weight argument to ShiraConfig so we can test no…
kkb-code Jun 26, 2025
8b95933
Minor Lint Fixes for SHiRA
kkb-code Jun 26, 2025
5aa3976
Fix the addmm_sparse_cuda error for bfloat and half for SHiRA and upd…
kkb-code Jun 28, 2025
e03fecb
Add the readme example into a file
kkb-code Jun 30, 2025
c977351
Make init_weights in SHiRA compatible with the rest of PEFT.
kkb-code Jul 1, 2025
8e1092a
Clean up and fix docstrings in SHiRA
kkb-code Jul 1, 2025
58838c9
Minor Lint Fixes for SHiRA
kkb-code Jul 1, 2025
e009677
Revert back to the original forward pass for SHiRA.
kkb-code Jul 1, 2025
a65194a
Merge main into SHiRA pull request
kkb-code Jul 2, 2025
104441a
Add SHiRA benchmark configs
kkb-code Jul 2, 2025
16a5f64
Minor Lint fixes for SHiRA
kkb-code Jul 2, 2025
2749212
make style fixed for SHiRA config.py and mask_functions.py
kkb-code Jul 3, 2025
fa9afc8
Fix SHiRA benchmark experiment name
kkb-code Jul 3, 2025
9004557
Merge branch 'main' into shira_adapters
kkb-code Jul 3, 2025
73b26b5
Add a new test for custom mask functions and fix some bugs
kkb-code Jul 3, 2025
771ba24
Fix SHiRA test formatting in test_shira.py and remove no_autocast tes…
kkb-code Jul 3, 2025
6f08b3a
Improve the shira-indices loading part for from_pretrained method
kkb-code Jul 4, 2025
c77d736
Fix example for merge_and_unload in SHiRA model.py
kkb-code Jul 4, 2025
8c68a5a
Attempt to use BufferDict for shira_indices in layer.py. Fails becaus…
kkb-code Jul 8, 2025
f413af1
Revert back the shira_indices BufferDict attempt
kkb-code Jul 8, 2025
1ce7f95
Fix minor things in shira_finetuning and test_shira. Remove shira_sft…
kkb-code Jul 8, 2025
7f3a706
Minor Lint fixes for SHiRA
kkb-code Jul 8, 2025
a3b2351
Fix the warning issue in shira config for from_pretrained and a minor…
kkb-code Jul 8, 2025
41df090
Merge branch 'main' into shira_adapters
kkb-code Jul 8, 2025
56ef038
Windows test -- no shira in save_and_load. Revert later.
kkb-code Jul 10, 2025
34af1fd
Bring back all os for tests
kkb-code Jul 10, 2025
ee9d1d5
Bring back shira part inside save_and_load.py
kkb-code Jul 10, 2025
f4ff99d
Remove just the shira indices saving and loading part. Save the shira…
kkb-code Jul 10, 2025
c94333a
More debugging. Change SHiRA indices dtype to torch.float32 when savi…
kkb-code Jul 11, 2025
cc47b6a
Bring back all tests and cast shira_indices to float only for windows…
kkb-code Jul 11, 2025
c93b36f
Fix 120 char limit in shira save_and_load.py
kkb-code Jul 11, 2025
2014c43
Minor typos fixed.
kkb-code Jul 11, 2025
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -126,6 +126,8 @@
      title: Trainable Tokens
    - local: package_reference/randlora
      title: RandLora
    - local: package_reference/shira
      title: SHiRA
    - local: package_reference/c3a
      title: C3A

35 changes: 35 additions & 0 deletions docs/source/package_reference/shira.md
@@ -0,0 +1,35 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Sparse High Rank Adapters

Sparse High Rank Adapters ([SHiRA](https://arxiv.org/abs/2406.13175)) are an alternative type of adapter that has been shown to have significant advantages over low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA on a variety of vision and language tasks. It also offers simpler and higher-quality multi-adapter fusion by significantly reducing concept loss, a common problem with low rank adapters. SHiRA adapts the model to any downstream task by directly finetuning a small, sparse subset of the base model's parameters.
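
To make this concrete, below is a minimal conceptual sketch of the SHiRA update for a single linear layer (an illustration only, not PEFT's internal implementation): a fixed binary mask selects `r * (m + n)` weight entries, matching the parameter budget of a rank-`r` LoRA, and only the masked delta is trained while the base weight stays frozen.

```python
# Conceptual sketch of SHiRA for one linear layer: W_eff = W + mask * delta
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(64, 64, bias=False)
base.weight.requires_grad_(False)  # the base weight stays frozen

r = 4
m, n = base.weight.shape
num_shira_params = r * (m + n)  # same trainable-parameter budget as a rank-r LoRA

# Fixed binary mask selecting which weight entries the adapter is allowed to change.
idx = torch.randperm(base.weight.numel())[:num_shira_params]
mask = torch.zeros(base.weight.numel())
mask[idx] = 1.0
mask = mask.view(m, n)

delta = nn.Parameter(torch.zeros(m, n))  # the only trainable tensor

x = torch.randn(8, n)
y = x @ (base.weight + mask * delta).T  # adapted forward pass
print(y.shape)  # torch.Size([8, 64])
```

Because the update lives directly in the weight's own coordinates, it can be fused into the base model with no inference overhead, and switching adapters amounts to swapping a sparse set of values.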

SHiRA currently has the following constraint:

- Only `nn.Linear` layers are supported.

The abstract from the paper is:

> Low Rank Adaptation (LoRA) has gained massive attention in the recent generative AI research. One of the main advantages of LoRA is its ability to be fused with pretrained models, adding no overhead during inference. However, from a mobile deployment standpoint, we can either avoid inference overhead in the fused mode but lose the ability to switch adapters rapidly, or suffer significant (up to 30% higher) inference latency while enabling rapid switching in the unfused mode. LoRA also exhibits concept-loss when multiple adapters are used concurrently. In this paper, we propose Sparse High Rank Adapters (SHiRA), a new paradigm which incurs no inference overhead, enables rapid switching, and significantly reduces concept-loss. Specifically, SHiRA can be trained by directly tuning only 1-2% of the base model weights while leaving others unchanged. This results in a highly sparse adapter which can be switched directly in the fused mode. We further provide theoretical and empirical insights on how high sparsity in SHiRA can aid multi-adapter fusion by reducing concept loss. Our extensive experiments on LVMs and LLMs demonstrate that finetuning only a small fraction of the parameters in the base model significantly outperforms LoRA while enabling both rapid switching and multi-adapter fusion. Finally, we provide a latency- and memory-efficient SHiRA implementation based on Parameter-Efficient Finetuning (PEFT) Library which trains at nearly the same speed as LoRA while consuming up to 16% lower peak GPU memory, thus making SHiRA easy to adopt for practical use cases. To demonstrate rapid switching benefits during inference, we show that loading SHiRA on a base model can be 5x-16x faster than LoRA fusion on a CPU.

## ShiraConfig

[[autodoc]] tuners.shira.config.ShiraConfig

## ShiraModel

[[autodoc]] tuners.shira.model.ShiraModel
73 changes: 73 additions & 0 deletions examples/shira_finetuning/README.md
@@ -0,0 +1,73 @@
# Sparse High Rank Adapters

## Introduction
Sparse High Rank Adapters ([SHiRA](https://arxiv.org/abs/2406.13175)) are an alternative type of adapter that has been shown to have significant advantages over low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA on a variety of vision and language tasks. It also offers simpler and higher-quality multi-adapter fusion by significantly reducing concept loss, a common problem with low rank adapters. SHiRA adapts the model to any downstream task by directly finetuning a small, sparse subset of the base model's parameters.

## Quick start
```python
import torch
from peft import ShiraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")
shira_config = ShiraConfig(
    r=32,
)
peft_model = get_peft_model(model, shira_config)
training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("shira-opt-350m")
```

For more options and a more detailed example, refer to the SHiRA finetuning script. Run it with:
```bash
python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m
```

If you want to run DDP with [accelerate](https://huggingface.co/docs/accelerate/en/index), run `accelerate config` to set up your DDP configuration, then run:
```bash
accelerate launch examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m
```
Add `--device_map cpu` if you want to finetune on CPU.

If you want to train SHiRA with a custom sparse mask function that requires custom keyword arguments, see the `custom_random_mask_function_with_custom_kwargs` function defined in the `shira_finetuning.py` script. You can enable it with the `--use_custom_random_mask_function_with_custom_kwargs` argument; without this argument, SHiRA defaults to a random sparse mask. Run the code as follows (a minimal sketch of such a mask function follows the command):
```bash
python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m --use_custom_random_mask_function_with_custom_kwargs
```
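
The sketch below mirrors the contract of the mask function used in `shira_finetuning.py`: `mask_fn(base_layer, r)` must return a single binary mask with the shape of the layer's weight, containing `r * (m + n)` ones for a weight of shape `(m, n)`, on the same device and with the same dtype as the weight. The `seeded_random_mask` wrapper and its `custom_seed` argument are illustrative names, not part of the PEFT API:

```python
import torch

from peft import ShiraConfig


def seeded_random_mask(custom_seed: int):
    # Returns a mask_fn(base_layer, r) closure, as expected when ShiraConfig uses mask_type="custom".
    def mask_fn(base_layer, r):
        m, n = base_layer.weight.shape
        num_shira_weights = r * (m + n)  # SHiRA's parameter budget for a linear layer
        generator = torch.Generator()
        generator.manual_seed(custom_seed)

        # Pick num_shira_weights random positions and set them to 1 in an otherwise zero mask.
        idx = torch.randperm(base_layer.weight.numel(), generator=generator)[:num_shira_weights].to(
            base_layer.weight.device
        )
        val = torch.ones_like(idx, dtype=base_layer.weight.dtype)
        mask = torch.zeros_like(base_layer.weight).view(1, -1)
        mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(m, n)
        return mask  # same shape, device, and dtype as base_layer.weight

    return mask_fn


config = ShiraConfig(r=32, mask_type="custom", task_type="CAUSAL_LM")
config.mask_fn = seeded_random_mask(custom_seed=120)
```

Any callable with this signature can be assigned to `config.mask_fn`, which lets you encode your own sparsity pattern.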


## Use the model
You can load and use the model like any other 🤗 PEFT model:
```python
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
shira_model = PeftModel.from_pretrained(model, "shira-opt-350m")
```

## Citation
```
@inproceedings{NEURIPS2024_18c0102c,
author = {Bhardwaj, Kartikeya and Pandey, Nilesh Prasad and Priyadarshi, Sweta and Ganapathy, Viswanath and Kadambi, Shreya and Esteves, Rafael and Borse, Shubhankar and Whatmough, Paul and Garrepalli, Risheek and Van Baalen, Mart and Teague, Harris and Nagel, Markus},
booktitle = {Advances in Neural Information Processing Systems},
editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
pages = {13685--13715},
publisher = {Curran Associates, Inc.},
title = {Sparse High Rank Adapters},
url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/18c0102cb7f1a02c14f0929089b2e576-Paper-Conference.pdf},
volume = {37},
year = {2024}
}
```
217 changes: 217 additions & 0 deletions examples/shira_finetuning/shira_finetuning.py
@@ -0,0 +1,217 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import Optional

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

from peft import (
    PeftModel,
    ShiraConfig,
    get_peft_model,
)


def train(
    base_model: str = "path/to/model",
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "shira",
    batch_size: int = 16,
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 16,
    eval_step: int = 100,
    save_step: int = 100,
    device_map: str = "auto",
    shira_r: int = 32,
    shira_target_modules: list[str] = None,
    torch_dtype: str = "float16",
    seed: Optional[int] = None,
    use_custom_random_mask_function_with_custom_kwargs: Optional[bool] = False,
):
    # Set device_map to the right place when enabling DDP.
    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
    if world_size > 1 and device_map != "cpu":
        from accelerate import Accelerator

        device_map = {"": Accelerator().process_index}
    # Set seed
    if seed is not None:
        set_seed(seed)
    model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": device_map}
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # For tokenizers with no pad token, like llama
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(example):
        full_prompt = generate_prompt(example)
        tokenized_full_prompt = tokenize(full_prompt)
        return tokenized_full_prompt

    def custom_random_mask_function_with_custom_kwargs(custom_arg):
        def mask_fn(base_layer, r):
            """
            This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py,
            except the seed is derived from custom_kwargs. Please use this as an example to create your own custom
            sparse masks that may use custom_kwargs. Remember, for a pretrained weight with shape (m, n), mask_fn
            must return only one mask (shape: (m, n)) which must be binary 0 or 1, with num_shira_parameters =
            r * (m + n) for linear layers. The device and dtype of the mask must be the same as the base layer's
            weight's device and dtype.
            """
            new_seed = custom_arg
            shape = base_layer.weight.shape
            num_shira_weights = r * (shape[0] + shape[1])
            random_generator = torch.Generator()
            random_generator.manual_seed(new_seed)

            idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
                base_layer.weight.device
            )
            val = torch.ones_like(idx.type(base_layer.weight.dtype))
            mask = torch.zeros_like(base_layer.weight.view(1, -1))
            mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)

            return mask

        return mask_fn

mask_type = "random" if not use_custom_random_mask_function_with_custom_kwargs else "custom"
config = ShiraConfig(
r=shira_r,
mask_type=mask_type,
target_modules=shira_target_modules,
task_type="CAUSAL_LM",
)
if use_custom_random_mask_function_with_custom_kwargs:
custom_arg = 120
custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
config.mask_fn = custom_mask_fn

model = get_peft_model(model, config)

data = load_dataset(data_path)

train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=batch_size,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
logging_steps=100,
optim="adamw_torch",
eval_strategy="steps",
save_strategy="steps",
eval_steps=eval_step,
save_steps=save_step,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True,
ddp_find_unused_parameters=False if world_size > 1 else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
trainer.train()
model.save_pretrained(output_dir)

# Delete the model and load it again from the checkpoint.
del model
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
model = PeftModel.from_pretrained(model, output_dir)


def generate_prompt(example):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{example["instruction"]}
### Response:
{example["output"]}"""


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, default="path/to/model")
    parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned")
    parser.add_argument("--output_dir", type=str, default="shira")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--cutoff_len", type=int, default=256)
    parser.add_argument("--val_set_size", type=int, default=16)
    parser.add_argument("--eval_step", type=int, default=100)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument("--device_map", type=str, default="auto")
    parser.add_argument("--shira_r", type=int, default=32)
    parser.add_argument("--shira_target_modules", type=str, default=None)
    parser.add_argument("--torch_dtype", type=str, default="float16")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--use_custom_random_mask_function_with_custom_kwargs", action="store_true")

    args = parser.parse_args()

    train(
        base_model=args.base_model,
        data_path=args.data_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        val_set_size=args.val_set_size,
        eval_step=args.eval_step,
        save_step=args.save_step,
        device_map=args.device_map,
        shira_r=args.shira_r,
        shira_target_modules=args.shira_target_modules,
        torch_dtype=args.torch_dtype,
        seed=args.seed,
        use_custom_random_mask_function_with_custom_kwargs=args.use_custom_random_mask_function_with_custom_kwargs,
    )
@@ -0,0 +1,15 @@
{
"auto_mapping": null,
"base_model_name_or_path": null,
"fan_in_fan_out": false,
"inference_mode": false,
"init_weights": true,
"mask_type": "random",
"modules_to_save": null,
"peft_type": "SHIRA",
"r": 32,
"random_seed": 42,
"revision": null,
"target_modules": null,
"task_type": null
}
@@ -0,0 +1,6 @@
{
"optimizer_kwargs": {
"lr": 3e-4
}
}

4 changes: 4 additions & 0 deletions src/peft/__init__.py
@@ -91,6 +91,8 @@
    PromptTuningInit,
    RandLoraConfig,
    RandLoraModel,
    ShiraConfig,
    ShiraModel,
    TrainableTokensConfig,
    TrainableTokensModel,
    VBLoRAConfig,
@@ -186,6 +188,8 @@
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"ShiraConfig",
"ShiraModel",
"TaskType",
"TrainableTokensConfig",
"TrainableTokensModel",
3 changes: 3 additions & 0 deletions src/peft/tuners/__init__.py
@@ -41,6 +41,7 @@
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .randlora import RandLoraConfig, RandLoraModel
from .shira import ShiraConfig, ShiraModel
from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel
from .vblora import VBLoRAConfig, VBLoRAModel
from .vera import VeraConfig, VeraModel
@@ -95,6 +96,8 @@
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"ShiraConfig",
"ShiraModel",
"TrainableTokensConfig",
"TrainableTokensModel",
"VBLoRAConfig",