Support for RoAd: 2D Rotary Adaptation #2678
Merged
12 commits:

- 95456bc RoAD adapter implementation (https://arxiv.org/pdf/2409.00119) (ppetrushkov)
- 2df411f Cleanup and documentation (ppetrushkov)
- 74ac723 Apply suggestions from code review (ppetrushkov)
- 8067d04 Change road enum, extra checks, documentation. (ppetrushkov)
- 6c894d9 Fix initialization and unit tests (ppetrushkov)
- e93cc20 Merge branch 'main' into road (ppetrushkov)
- 06a84b9 Apply suggestions from code review (ppetrushkov)
- 070ca71 Add gpu tests for bnb road (ppetrushkov)
- 319ab6d Style (ppetrushkov)
- 93b71bc Style with doc-builder (ppetrushkov)
- e4ce021 Add to toc (ppetrushkov)
- eaa8ca5 Merge remote-tracking branch 'upstream/main' into road (ppetrushkov)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| <!--Copyright 2025 The HuggingFace Team. All rights reserved. | ||
| | ||
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
| the License. You may obtain a copy of the License at | ||
| | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| | ||
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
| specific language governing permissions and limitations under the License. | ||
| | ||
| ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | ||
| rendered properly in your Markdown viewer. | ||
| | ||
| --> | ||
| | ||
| # RoAd | ||
| | ||
| [RoAd](https://arxiv.org/pdf/2409.00119) is a parameter‑efficient fine‑tuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions. RoAd achieves performance competitive with or superior to other PEFT methods with under 0.1% trainable parameters. Unlike LoRA, whose low‑rank updates require batch matrix multiplications when requests in a batch use different adapters, RoAd’s sparse rotations can be reformulated as simple element‑wise operations, yielding significantly higher serving throughput for heterogeneous requests in the same batch, i.e. when serving multiple adapters simultaneously. Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, interpreting its sparse 2D rotations as task-specific interventions within learned subspaces of hidden representations. These orthogonal subspaces can be composed to merge multiple task-specific behaviors, such as multilingual capabilities or instruction following, without additional fine-tuning, enabling modular, interpretable adaptations in LLMs. | ||
| | ||
| Finetuning with RoAd typically requires a higher learning rate than LoRA or similar methods, around 1e-3. Currently, RoAd only supports linear layers; it can be used on models quantized with bitsandbytes (4-bit or 8-bit). | ||
| | ||
| To run inference with different RoAd adapters in the same batch, see [Inference with different LoRA adapters in the same batch](../developer_guides/lora#inference-with-different-lora-adapters-in-the-same-batch). | ||
| | ||
| ## RoadConfig | ||
| | ||
| [[autodoc]] tuners.road.config.RoadConfig | ||
| | ||
| ## RoadModel | ||
| | ||
| [[autodoc]] tuners.road.model.RoadModel | ||
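A minimal pure-Python sketch of the element-wise reformulation described above, for illustration only (it is not PEFT's implementation). It assumes the `road_1` parameterization from the paper (one angle θ and one scale α shared by each 2×2 block) and assumes that dimension `i` is paired with dimension `i + d/2`, as in RoPE-style "rotate half" code; the helper names are hypothetical. It checks that `x * (α cos θ) + rotate_half(x) * (α sin θ)` matches applying an explicit scaled 2×2 rotation to every pair.

```python
import math


def rotate_half(x):
    """Map [x1, x2] -> [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + list(x1)


def road1_elementwise(x, theta, alpha):
    """Hypothetical RoAd_1-style transform using only element-wise operations.

    theta, alpha: one value per pair (len(x) // 2 each); pair i is (i, i + len(x) // 2).
    """
    # Broadcast each pair's (alpha*cos, alpha*sin) to both dimensions of the pair.
    cos_col = [a * math.cos(t) for a, t in zip(alpha, theta)] * 2
    sin_col = [a * math.sin(t) for a, t in zip(alpha, theta)] * 2
    rx = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos_col, rx, sin_col)]


def road1_matrix(x, theta, alpha):
    """Reference version: apply an explicit scaled 2x2 rotation to each pair."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        a, t = alpha[i], theta[i]
        u, v = x[i], x[i + half]
        out[i] = a * (math.cos(t) * u - math.sin(t) * v)
        out[i + half] = a * (math.sin(t) * u + math.cos(t) * v)
    return out


x = [1.0, 2.0, 3.0, 4.0]  # d = 4 -> pairs (0, 2) and (1, 3)
theta, alpha = [0.3, -1.2], [1.0, 0.5]
fast = road1_elementwise(x, theta, alpha)
ref = road1_matrix(x, theta, alpha)
assert all(abs(a - b) < 1e-12 for a, b in zip(fast, ref))
```

With θ = 0 and α = 1 every block is the identity, so a zero-angle, unit-scale adapter leaves the hidden states unchanged.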
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # RoAd: 3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability | ||
| | ||
| | ||
| ## Introduction | ||
| | ||
| [RoAd](https://arxiv.org/pdf/2409.00119) is a novel method that adapts LLMs using simple 2D rotations. It is highly parameter-efficient, | ||
| achieving strong performance with less than 0.1% trainable parameters. | ||
| RoAd also supports efficient serving of mixed-adapter requests within a batch, incurring only element-wise computation overhead rather than costly batch matrix multiplications. | ||
| Additionally, it improves model interpretability through structured and composable transformations. | ||
| | ||
| ## Quick start | ||
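| ```python | ||
| import torch | ||
| from peft import RoadConfig, get_peft_model | ||
| from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments | ||
| from datasets import load_dataset | ||
| | ||
| model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda") | ||
| tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") | ||
| tokenizer.pad_token = tokenizer.eos_token  # needed for padding; reuse EOS as the example script does | ||
| dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") | ||
| # Trainer expects tokenized inputs, so tokenize the "text" column first | ||
| dataset = dataset.map( | ||
| lambda examples: tokenizer(examples["text"], truncation=True, max_length=2048), | ||
| batched=True, | ||
| remove_columns=dataset.column_names, | ||
| ) | ||
| road_config = RoadConfig( | ||
| variant="road_1",  # one of "road_1", "road_2", "road_4" | ||
| ) | ||
| peft_model = get_peft_model(model, road_config) | ||
| trainer = Trainer( | ||
| model=peft_model, | ||
| args=TrainingArguments(output_dir="road-llama-7b", learning_rate=1e-3), | ||
| train_dataset=dataset, | ||
| data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),  # pads batches and builds causal-LM labels | ||
| ) | ||
| trainer.train() | ||
| peft_model.save_pretrained("road-llama-7b") | ||
| ``` | ||
| | ||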
| RoAd requires a higher learning rate than LoRA and similar approaches; set it to around 1e-3. | ||
| | ||
| Run the finetuning script with: | ||
| | ||
| ```bash | ||
| python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco | ||
| ``` | ||
| | ||
| RoAd also supports quantization. To use 4-bit quantization, try: | ||
| | ||
| ```bash | ||
| python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --quantize | ||
| ``` | ||
| | ||
| ### Full example of the script | ||
| ```bash | ||
| python road_finetuning.py \ | ||
| --base_model "PATH_TO_MODEL" \ | ||
| --data_path "PATH_TO_DATASET" \ | ||
| --output_dir "PATH_TO_OUTPUT_DIR" \ | ||
| --batch_size 1 \ | ||
| --num_epochs 3 \ | ||
| --learning_rate 1e-3 \ | ||
| --cutoff_len 512 \ | ||
| --val_set_size 500 \ | ||
| --quantize \ | ||
| --eval_step 10 \ | ||
| --save_step 100 \ | ||
| --device "cuda:0" \ | ||
| --variant road_1 \ | ||
| --road_target_modules "q_proj,k_proj,v_proj,o_proj" \ | ||
| --hub_model_id "YOUR_HF_REPO" \ | ||
| --push_to_hub | ||
| ``` | ||
| ## Use the model on 🤗 | ||
| You can load and use the model like any other 🤗 model. | ||
| ```python | ||
| from transformers import AutoModel | ||
| model = AutoModel.from_pretrained("ppetrushkov/llama-2-7b-sql-road-test") | ||
| ``` | ||
| | ||
| | ||
| ## Citation | ||
| ``` | ||
| @inproceedings{ | ||
| liao2024in, | ||
| title={3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability}, | ||
| author={Baohao Liao and Christof Monz}, | ||
| booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, | ||
| year={2024}, | ||
| url={https://openreview.net/forum?id=rYjYwuM6yH} | ||
| } | ||
| ``` | ||
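The README states that mixed-adapter batches only add element-wise work. A hypothetical pure-Python sketch of why: each request carries its own adapter's per-pair angles and scales, so the whole batch is handled by one element-wise rule with row-specific parameters, and no per-adapter matrix multiplication is needed. The adapter names, the `road_1`-style parameterization (one angle and one scale per pair), and the pairing of dimension `j` with `j + d/2` are illustrative assumptions, not PEFT's code.

```python
import math


def road1_apply(x, theta, alpha):
    """Hypothetical RoAd_1-style transform of one hidden vector.

    Dimension j is paired with j + d/2; each pair has one angle and one scale.
    """
    half = len(x) // 2
    out = []
    for j, xj in enumerate(x):
        i = j % half  # index of the pair that dimension j belongs to
        # "rotate half" term: -x[j + half] in the first half, +x[j - half] in the second
        rotated = -x[j + half] if j < half else x[j - half]
        out.append(alpha[i] * (math.cos(theta[i]) * xj + math.sin(theta[i]) * rotated))
    return out


def apply_mixed_batch(hidden_states, adapter_names, adapters):
    """Apply the adapter named for each row; rows may use different adapters."""
    return [road1_apply(h, *adapters[name]) for h, name in zip(hidden_states, adapter_names)]


# Two hypothetical adapters, each stored as (theta, alpha) with one entry per pair.
adapters = {
    "identity": ([0.0, 0.0], [1.0, 1.0]),
    "quarter_turn": ([math.pi / 2, math.pi / 2], [1.0, 1.0]),
}
batch = [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
out = apply_mixed_batch(batch, ["identity", "quarter_turn"], adapters)
```

In a tensor implementation, the per-row angles and scales would be gathered into `[batch, d]` tensors and applied in a single element-wise multiply-add, which is the property the README highlights.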
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| # Copyright 2025-present the HuggingFace Inc. team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| | ||
| import os | ||
| | ||
| import torch | ||
| from datasets import load_dataset | ||
| from transformers import ( | ||
| AutoModelForCausalLM, | ||
| AutoTokenizer, | ||
| BitsAndBytesConfig, | ||
| DataCollatorForLanguageModeling, | ||
| Trainer, | ||
| TrainingArguments, | ||
| ) | ||
| | ||
| from peft import RoadConfig, get_peft_model, prepare_model_for_kbit_training | ||
| | ||
| | ||
| def train_model( | ||
| base_model: str, | ||
| data_path: str, | ||
| output_dir: str, | ||
| batch_size: int, | ||
| num_epochs: int, | ||
| learning_rate: float, | ||
| cutoff_len: int, | ||
| val_set_size: int, | ||
| quantize: bool, | ||
| eval_step: int, | ||
| save_step: int, | ||
| device: str, | ||
| variant: str, | ||
| road_target_modules: str, | ||
| hub_model_id: str, | ||
| push_to_hub: bool, | ||
| ): | ||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
| hf_token = os.getenv("HF_TOKEN") | ||
| | ||
| # Setup device | ||
| device = torch.device(device) | ||
| print(f"Using device: {device}") | ||
| | ||
| # load tokenizer | ||
| tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token) | ||
| | ||
| # Optionally load the model in 4-bit with bitsandbytes | ||
| if quantize: | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| base_model, | ||
| token=hf_token, | ||
| quantization_config=BitsAndBytesConfig( | ||
| load_in_4bit=True, | ||
| bnb_4bit_compute_dtype=( | ||
| torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 | ||
| ), | ||
| bnb_4bit_use_double_quant=True, | ||
| bnb_4bit_quant_type="nf4", | ||
| ), | ||
| ) | ||
| # setup for quantized training | ||
| model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) | ||
| else: | ||
| model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token, device_map="auto") | ||
| # RoAd config for the PEFT model | ||
| road_config = RoadConfig( | ||
| variant=variant,  # RoAd variant: "road_1", "road_2" or "road_4" | ||
| target_modules=( | ||
| road_target_modules.split(",") | ||
| if road_target_modules | ||
| else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] | ||
| ), | ||
| ) | ||
| | ||
| # get the peft model with RoAd config | ||
| model = get_peft_model(model, road_config) | ||
| | ||
| model.to(device) # MODEL TO GPU/CUDA | ||
| tokenizer.pad_token = tokenizer.eos_token | ||
| | ||
| # Load the dataset | ||
| dataset = load_dataset(data_path) | ||
| | ||
| def tokenize_function(examples): | ||
| inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len) | ||
| inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task | ||
| return inputs | ||
| | ||
| # Tokenize the dataset and prepare for training | ||
| tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) | ||
| | ||
| # Causal-LM collator: rebuilds labels from input_ids and masks pad tokens with -100 | ||
| data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) | ||
| | ||
| # Define training arguments | ||
| training_args = TrainingArguments( | ||
| output_dir=output_dir, | ||
| num_train_epochs=num_epochs, | ||
| per_device_train_batch_size=batch_size, | ||
| per_device_eval_batch_size=batch_size, | ||
| warmup_steps=100, | ||
| weight_decay=0.01, | ||
| logging_dir="./logs", | ||
| logging_steps=eval_step, | ||
| save_steps=save_step, | ||
| save_total_limit=2, | ||
| push_to_hub=push_to_hub, | ||
| hub_model_id=hub_model_id, | ||
| gradient_accumulation_steps=16, | ||
| fp16=True, | ||
| learning_rate=learning_rate, | ||
| hub_token=hf_token, | ||
| ) | ||
| | ||
| # Clear CUDA cache to free memory | ||
| torch.cuda.empty_cache() | ||
| | ||
| # Initialize the Trainer | ||
| trainer = Trainer( | ||
| model=model, | ||
| args=training_args, | ||
| train_dataset=tokenized_datasets["train"], | ||
| eval_dataset=tokenized_datasets["test"], | ||
| data_collator=data_collator, | ||
| ) | ||
| | ||
| # Start model training | ||
| trainer.train() | ||
| | ||
| # Save and push the trained model and tokenizer | ||
| if push_to_hub: | ||
| # Push the main model to the hub | ||
| trainer.push_to_hub(commit_message="Fine-tuned model") | ||
| | ||
| # Save the model and tokenizer locally | ||
| model.save_pretrained(output_dir) | ||
| tokenizer.save_pretrained(output_dir) | ||
| | ||
| | ||
| if __name__ == "__main__": | ||
| import argparse | ||
| | ||
| parser = argparse.ArgumentParser(description="Fine-tune LLaMA with RoAd and PEFT") | ||
| parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name") | ||
| parser.add_argument( | ||
| "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name" | ||
| ) | ||
| parser.add_argument( | ||
| "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model" | ||
| ) | ||
| parser.add_argument("--batch_size", type=int, default=1, help="Batch size") | ||
| parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") | ||
| parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate") | ||
| parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization") | ||
| parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size") | ||
| parser.add_argument("--quantize", action="store_true", help="Use quantization") | ||
| parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval") | ||
| parser.add_argument("--save_step", type=int, default=100, help="Save step interval") | ||
| parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training") | ||
| parser.add_argument( | ||
| "--variant", type=str, default="road_1", choices=["road_1", "road_2", "road_4"], help="RoAd variant" | ||
| ) | ||
| parser.add_argument( | ||
| "--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAd" | ||
| ) | ||
| parser.add_argument( | ||
| "--hub_model_id", | ||
| type=str, | ||
| default="path/to/repo", | ||
| help="Repository name to push the model on the Hugging Face Hub", | ||
| ) | ||
| parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub") | ||
| args = parser.parse_args() | ||
| train_model( | ||
| base_model=args.base_model, | ||
| data_path=args.data_path, | ||
| output_dir=args.output_dir, | ||
| batch_size=args.batch_size, | ||
| num_epochs=args.num_epochs, | ||
| learning_rate=args.learning_rate, | ||
| cutoff_len=args.cutoff_len, | ||
| val_set_size=args.val_set_size, | ||
| quantize=args.quantize, | ||
| eval_step=args.eval_step, | ||
| save_step=args.save_step, | ||
| device=args.device, | ||
| variant=args.variant, | ||
| road_target_modules=args.road_target_modules, | ||
| hub_model_id=args.hub_model_id, | ||
| push_to_hub=args.push_to_hub, | ||
| ) | ||
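In the script above, `tokenize_function` copies `input_ids` into `labels`, but `DataCollatorForLanguageModeling(tokenizer, mlm=False)` rebuilds `labels` from `input_ids` and sets every position equal to `tokenizer.pad_token_id` to -100, the index the loss ignores. Because the script sets `pad_token = eos_token`, genuine EOS tokens are masked as well. A small pure-Python sketch of that masking rule; the token IDs and the helper name are made up for illustration, and the one-position shift for next-token prediction happens inside the model, not here.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def causal_lm_labels(input_ids, pad_token_id):
    """Labels for causal-LM training: input_ids with pad positions set to IGNORE_INDEX."""
    return [[IGNORE_INDEX if tok == pad_token_id else tok for tok in row] for row in input_ids]


EOS_ID = 2  # hypothetical EOS id, reused as the pad id as in the script
batch = [
    [1, 523, 88, 2, 2, 2],     # real EOS at position 3, followed by padding
    [1, 17, 940, 311, 6, 2],   # no padding, but the final EOS is still masked
]
labels = causal_lm_labels(batch, pad_token_id=EOS_ID)
print(labels)  # [[1, 523, 88, -100, -100, -100], [1, 17, 940, 311, 6, -100]]
```

If the model should learn to emit EOS, one option is a pad token that differs from EOS, so that only true padding is excluded from the loss.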
12 changes: 12 additions & 0 deletions
method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/adapter_config.json
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| { | ||
| "auto_mapping": null, | ||
| "base_model_name_or_path": null, | ||
| "group_size": 64, | ||
| "inference_mode": false, | ||
| "init_weights": true, | ||
| "peft_type": "ROAD", | ||
| "revision": null, | ||
| "target_modules": null, | ||
| "task_type": null, | ||
| "variant": "road_2" | ||
| } | ||
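The config above selects `variant: "road_2"` with `group_size: 64`. A back-of-the-envelope sketch of how the variants scale, assuming the paper's parameterization: each 2×2 block has 2 (RoAd_1), 4 (RoAd_2) or 8 (RoAd_4) trainable angles and scales, so a targeted linear layer with `d` output features gains `d`, `2d` or `4d` parameters. The Llama-7B-like shapes and the ~6.7B base-parameter figure are approximations chosen for illustration, and the helper names are hypothetical.

```python
# Assumed from the paper: 2, 4 or 8 parameters per 2x2 block, i.e. 1, 2 or 4 per output dim.
PARAMS_PER_OUTPUT_DIM = {"road_1": 1, "road_2": 2, "road_4": 4}


def road_params_for_layer(out_features, variant):
    """Trainable RoAd parameters added to one linear layer (hypothetical helper)."""
    if out_features % 2:
        raise ValueError("RoAd rotates pairs of dimensions, so out_features must be even")
    return out_features * PARAMS_PER_OUTPUT_DIM[variant]


def road_trainable_share(layer_out_features, variant, base_param_count):
    """Total RoAd parameters over all targeted layers and their share of the base model."""
    total = sum(road_params_for_layer(d, variant) for d in layer_out_features)
    return total, total / base_param_count


# Illustrative Llama-7B-like setup: 32 decoder layers, q/k/v/o projections of width 4096.
layers = [4096] * (32 * 4)
BASE_PARAMS = 6.7e9  # approximate size of a 7B model
for variant in ("road_1", "road_2", "road_4"):
    total, share = road_trainable_share(layers, variant, BASE_PARAMS)
    print(f"{variant}: {total:,} params ({share:.4%} of the base model)")
```

Under these assumptions, even `road_4` on all four attention projections stays well below the 0.1% trainable-parameter figure quoted in the docs.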
5 changes: 5 additions & 0 deletions
method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/training_params.json
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| { | ||
| "optimizer_kwargs": { | ||
| "lr": 1e-3 | ||
| } | ||
| } | ||
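The experiment above uses a learning rate of 1e-3, matching the docs' recommendation for RoAd. In the standalone example script, the learning rate is combined with `warmup_steps=100` and `gradient_accumulation_steps=16`, and since `lr_scheduler_type` is left unset, the 🤗 Trainer falls back to its default linear schedule. The sketch below re-implements that warmup-then-linear-decay rule in plain Python to show the rate at a few steps; the total step count is a hypothetical value and the helper names are ours.

```python
def linear_warmup_decay_lr(step, peak_lr, warmup_steps, total_steps):
    """Learning rate at an optimizer step: linear warmup to peak_lr, then linear decay to 0.

    Same multiplier as transformers' get_linear_schedule_with_warmup, re-implemented here.
    """
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def effective_batch_size(per_device_batch_size, grad_accum_steps, num_devices=1):
    """Number of examples that contribute to each optimizer step."""
    return per_device_batch_size * grad_accum_steps * num_devices


TOTAL_STEPS = 1000  # hypothetical number of optimizer steps
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_decay_lr(step, 1e-3, 100, TOTAL_STEPS))
print("effective batch size:", effective_batch_size(1, 16))  # script defaults: batch_size=1, accumulation=16
```

The peak of 1e-3 is only reached at the end of warmup, so very short runs spend a noticeable fraction of their steps below the recommended rate.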