2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -132,6 +132,8 @@
title: C3A
- local: package_reference/miss
title: MiSS
- local: package_reference/road
title: RoAd

title: Adapters
- sections:
31 changes: 31 additions & 0 deletions docs/source/package_reference/road.md
@@ -0,0 +1,31 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# RoAd

[RoAd](https://arxiv.org/pdf/2409.00119) is a parameter-efficient fine-tuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions. RoAd achieves competitive or superior performance compared to other PEFT methods with under 0.1% trainable parameters. Unlike LoRA's batched low-rank updates, RoAd's sparse rotations reformulate to simple element-wise operations, yielding significantly higher serving throughput when handling heterogeneous requests in the same batch, i.e. serving multiple adapters simultaneously. Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, interpreting its sparse 2D rotations as task-specific interventions within learned subspaces of hidden representations. These orthogonal subspaces can be composed to merge multiple task-specific behaviors, such as multilingual capabilities or instruction following, without additional fine-tuning, enabling modular, interpretable adaptations in LLMs.
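
To make the element-wise reformulation concrete, here is a minimal PyTorch sketch of the core idea. It is only an illustration, not PEFT's internal implementation: the real RoAd layers learn the angles and scales as parameters and may group and pair hidden dimensions differently.

```python
import torch

def road_rotate(h: torch.Tensor, theta1: torch.Tensor, theta2: torch.Tensor) -> torch.Tensor:
    """Apply learned 2x2 rotations to pairs of hidden dimensions, element-wise.

    h:      (..., d) output of a frozen linear layer, d even
    theta1: (d,) terms playing the role of scale * cos(angle), shared within each pair
    theta2: (d,) terms playing the role of scale * sin(angle), shared within each pair
    """
    # Build the "rotate half" counterpart of each pair: (h0, h1) -> (-h1, h0),
    # the same trick used by rotary position embeddings.
    pairs = h.reshape(*h.shape[:-1], -1, 2)
    h_rot = torch.stack((-pairs[..., 1], pairs[..., 0]), dim=-1).reshape_as(h)
    # The block-diagonal rotation collapses into two element-wise multiplies,
    # which keeps mixed-adapter batches cheap to serve.
    return theta1 * h + theta2 * h_rot

d = 8
angle, scale = torch.randn(d // 2), torch.ones(d // 2)
theta1 = (scale * torch.cos(angle)).repeat_interleave(2)  # [c0, c0, c1, c1, ...]
theta2 = (scale * torch.sin(angle)).repeat_interleave(2)  # [s0, s0, s1, s1, ...]
h = torch.randn(2, d)                                     # toy hidden states
print(road_rotate(h, theta1, theta2).shape)               # torch.Size([2, 8])
```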

Finetuning with RoAd typically requires a higher learning rate than LoRA or similar methods, around 1e-3. Currently, RoAd only supports linear layers, and it can be used with models quantized with bitsandbytes (4-bit or 8-bit).
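
For example, combining RoAd with a bitsandbytes 4-bit quantized model could look like the following sketch (the model name and target modules are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import RoadConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", quantization_config=bnb_config)
model = prepare_model_for_kbit_training(model)  # standard setup for k-bit training

# RoAd adapts only linear layers; train with a relatively high learning rate (around 1e-3).
config = RoadConfig(variant="road_1", target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, config)
model.print_trainable_parameters()
```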

For running inference with different RoAd adapters in the same batch, see [Inference with different LoRA adapters in the same batch](../developer_guides/lora#inference-with-different-lora-adapters-in-the-same-batch).
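
As with LoRA, different RoAd adapters trained on the same base model can serve heterogeneous requests in one batch by passing `adapter_names` to `generate`. The sketch below uses hypothetical adapter repositories and names purely to illustrate the call pattern.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left padding for decoder-only generation

# Hypothetical Hub repositories holding RoAd adapters trained on the same base model.
model = PeftModel.from_pretrained(base_model, "user/road-adapter-chat", adapter_name="chat")
model.load_adapter("user/road-adapter-sql", adapter_name="sql")

prompts = ["Tell me a joke.", "SELECT the top 5 customers by revenue.", "Summarize this text."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

# One adapter name per sample; "__base__" routes that sample through the base model only.
outputs = model.generate(**inputs, adapter_names=["chat", "sql", "__base__"], max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```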

## RoadConfig

[[autodoc]] tuners.road.config.RoadConfig

## RoadModel

[[autodoc]] tuners.road.model.RoadModel
88 changes: 88 additions & 0 deletions examples/road_finetuning/README.md
@@ -0,0 +1,88 @@
# RoAd: 3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability


## Introduction

[RoAd](https://arxiv.org/pdf/2409.00119) is a novel method that adapts LLMs using simple 2D rotations. It is highly parameter-efficient,
achieving strong performance with less than 0.1% trainable parameters.
RoAd also supports efficient serving of mixed-adapter requests within a batch, incurring only element-wise computation overhead rather than costly batch matrix multiplications.
Additionally, it improves model interpretability through structured and composable transformations.

## Quick start
```python
import torch
from datasets import load_dataset
from peft import RoadConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
road_config = RoadConfig(
    variant="road_1",
)
peft_model = get_peft_model(model, road_config)
training_args = SFTConfig(dataset_text_field="text", max_seq_length=2048)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
peft_model.save_pretrained("road-llama-7b")
```

RoAd requires a higher learning rate than LoRA and similar approaches; set it to around 1e-3.

Run the finetuning script with:

```bash
python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
```

RoAd also supports quantization. To use 4-bit quantization try:

```bash
python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --quantize
```

### Full example of the script
```bash
python road_finetuning.py \
--base_model "PATH_TO_MODEL" \
--data_path "PATH_TO_DATASET" \
--output_dir "PATH_TO_OUTPUT_DIR" \
--batch_size 1 \
--num_epochs 3 \
--learning_rate 1e-3 \
--cutoff_len 512 \
--val_set_size 500 \
--quantize \
--eval_step 10 \
--save_step 100 \
--device "cuda:0" \
    --variant road_1 \
--road_target_modules "q_proj,k_proj,v_proj,o_proj" \
--hub_model_id "YOUR_HF_REPO" \
--push_to_hub
```
## Use the model on 🤗
You can load and use the model like any other 🤗 Transformers model.
```python
from transformers import AutoModel
model = AutoModel.from_pretrained("ppetrushkov/llama-2-7b-sql-road-test")
```
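
If the repository contains only the RoAd adapter weights, you can alternatively let PEFT load the base model recorded in the adapter config and attach the adapter in one call (a sketch using the same example repository):

```python
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("ppetrushkov/llama-2-7b-sql-road-test")
```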


## Citation
```
@inproceedings{
liao2024in,
title={3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability},
author={Baohao Liao and Christof Monz},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=rYjYwuM6yH}
}
```
203 changes: 203 additions & 0 deletions examples/road_finetuning/road_finetuning.py
@@ -0,0 +1,203 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)

from peft import RoadConfig, get_peft_model, prepare_model_for_kbit_training


def train_model(
base_model: str,
data_path: str,
output_dir: str,
batch_size: int,
num_epochs: int,
learning_rate: float,
cutoff_len: int,
val_set_size: int,
quantize: bool,
eval_step: int,
save_step: int,
device: str,
variant: str,
road_target_modules: str,
hub_model_id: str,
push_to_hub: bool,
):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
hf_token = os.getenv("HF_TOKEN")

# Setup device
device = torch.device(device)
print(f"Using device: {device}")

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)

    # Optionally quantize the model with bitsandbytes 4-bit
if quantize:
model = AutoModelForCausalLM.from_pretrained(
base_model,
token=hf_token,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=(
torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
),
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
),
)
# setup for quantized training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
else:
model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token, device_map="auto")
# RoAd config for the PEFT model
road_config = RoadConfig(
        variant=variant,  # RoAd variant: road_1, road_2 or road_4
target_modules=(
road_target_modules.split(",")
if road_target_modules
else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
),
)

# get the peft model with RoAd config
model = get_peft_model(model, road_config)

    model.to(device)  # move the model to the selected device
tokenizer.pad_token = tokenizer.eos_token

# Load the dataset
dataset = load_dataset(data_path)

def tokenize_function(examples):
inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task
return inputs

# Tokenize the dataset and prepare for training
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

# Data collator to dynamically pad the batched examples
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Define training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_steps=100,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=eval_step,
save_steps=save_step,
save_total_limit=2,
push_to_hub=push_to_hub,
hub_model_id=hub_model_id,
gradient_accumulation_steps=16,
fp16=True,
learning_rate=learning_rate,
hub_token=hf_token,
)

# Clear CUDA cache to free memory
torch.cuda.empty_cache()

# Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
data_collator=data_collator,
)

# Start model training
trainer.train()

# Save and push the trained model and tokenizer
if push_to_hub:
# Push the main model to the hub
trainer.push_to_hub(commit_message="Fine-tuned model")

# Save the model and tokenizer locally
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Fine-tune LLaMA with DoRA and PEFT")
parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
parser.add_argument(
"--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
)
parser.add_argument(
"--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate")
parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
parser.add_argument("--quantize", action="store_true", help="Use quantization")
parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training")
parser.add_argument(
"--variant", type=str, default="road_1", choices=["road_1", "road_2", "road_4"], help="RoAD variant"
)
parser.add_argument(
"--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAd"
)
parser.add_argument(
"--hub_model_id",
type=str,
default="path/to/repo",
help="Repository name to push the model on the Hugging Face Hub",
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
args = parser.parse_args()
train_model(
base_model=args.base_model,
data_path=args.data_path,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
cutoff_len=args.cutoff_len,
val_set_size=args.val_set_size,
quantize=args.quantize,
eval_step=args.eval_step,
save_step=args.save_step,
device=args.device,
variant=args.variant,
road_target_modules=args.road_target_modules,
hub_model_id=args.hub_model_id,
push_to_hub=args.push_to_hub,
)
@@ -0,0 +1,12 @@
{
"auto_mapping": null,
"base_model_name_or_path": null,
"group_size": 64,
"inference_mode": false,
"init_weights": true,
"peft_type": "ROAD",
"revision": null,
"target_modules": null,
"task_type": null,
"variant": "road_2"
}
@@ -0,0 +1,5 @@
{
"optimizer_kwargs": {
"lr": 1e-3
}
}
4 changes: 4 additions & 0 deletions src/peft/__init__.py
@@ -93,6 +93,8 @@
PromptTuningInit,
RandLoraConfig,
RandLoraModel,
RoadConfig,
RoadModel,
ShiraConfig,
ShiraModel,
TrainableTokensConfig,
@@ -194,6 +196,8 @@
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"RoadConfig",
"RoadModel",
"ShiraConfig",
"ShiraModel",
"TaskType",
3 changes: 3 additions & 0 deletions src/peft/tuners/__init__.py
@@ -42,6 +42,7 @@
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .randlora import RandLoraConfig, RandLoraModel
from .road import RoadConfig, RoadModel
from .shira import ShiraConfig, ShiraModel
from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel
from .vblora import VBLoRAConfig, VBLoRAModel
@@ -99,6 +100,8 @@
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"RoadConfig",
"RoadModel",
"ShiraConfig",
"ShiraModel",
"TrainableTokensConfig",