From 95456bc414ce68222f45d5d5e9def322e616c2d6 Mon Sep 17 00:00:00 2001
From: Pavel Petrushkov
Date: Wed, 25 Jun 2025 16:04:48 +0200
Subject: [PATCH 01/10] RoAD adapter implementation (https://arxiv.org/pdf/2409.00119)

---
 examples/road_finetuning/road_finetuning.py | 187 +++++++++
 examples/sequence_classification/RoAD.py    | 124 ++++++
 src/peft/__init__.py                        |   4 +
 src/peft/tuners/__init__.py                 |   3 +
 src/peft/tuners/road/__init__.py            |  44 +++
 src/peft/tuners/road/bnb.py                 | 407 ++++++++++++++++++++
 src/peft/tuners/road/config.py              | 123 ++++++
 src/peft/tuners/road/layer.py               | 393 +++++++++++++++++++
 src/peft/tuners/road/model.py               | 302 +++++++++++++++
 src/peft/utils/peft_types.py                |   2 +
 tests/test_common_gpu.py                    | 199 ++++++++++
 tests/test_config.py                        |   2 +
 tests/test_custom_models.py                 |  39 ++
 tests/test_decoder_models.py                |  13 +-
 tests/test_encoder_decoder_models.py        |   9 +
 tests/test_feature_extraction_models.py     |   9 +
 tests/test_seq_classifier.py                |   9 +
 tests/testing_common.py                     |   1 +
 18 files changed, 1869 insertions(+), 1 deletion(-)
 create mode 100644 examples/road_finetuning/road_finetuning.py
 create mode 100644 examples/sequence_classification/RoAD.py
 create mode 100644 src/peft/tuners/road/__init__.py
 create mode 100644 src/peft/tuners/road/bnb.py
 create mode 100644 src/peft/tuners/road/config.py
 create mode 100644 src/peft/tuners/road/layer.py
 create mode 100644 src/peft/tuners/road/model.py

diff --git a/examples/road_finetuning/road_finetuning.py b/examples/road_finetuning/road_finetuning.py
new file mode 100644
index 0000000000..ee8005dae1
--- /dev/null
+++ b/examples/road_finetuning/road_finetuning.py
@@ -0,0 +1,187 @@
+import os
+
+import torch
+from datasets import load_dataset
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
+    Trainer,
+    TrainingArguments,
+)
+
+from peft import RoadConfig, get_peft_model, prepare_model_for_kbit_training
+
+
+def train_model(
+    base_model: str,
+    data_path: str,
+    output_dir: str,
+    batch_size: int,
+    num_epochs: int,
+    learning_rate: float,
+    cutoff_len: int,
+    val_set_size: int,
+    quantize: bool,
+    eval_step: int,
+    save_step: int,
+    device: str,
+    variant: str,
+    road_target_modules: str,
+    hub_model_id: str,
+    push_to_hub: bool,
+):
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+    hf_token = os.getenv("HF_TOKEN")
+
+    # Setup device
+    device = torch.device(device)
+    print(f"Using device: {device}")
+
+    # load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
+
+    # Quantized RoAD: load the base model in 4-bit if quantized training is requested
+    if quantize:
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            token=hf_token,
+            quantization_config=BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=(
+                    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
+                ),
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+            ),
+        )
+        # setup for quantized training
+        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token, device_map="auto")
+    # RoAD config for the PEFT model
+    road_config = RoadConfig(
+        variant=variant,  # RoAD variant: "1", "2" or "4"
+        target_modules=(
+            road_target_modules.split(",")
+            if road_target_modules
+            else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+        ),
+    )
+
+    # get the peft model with RoAD config
+    model = get_peft_model(model, road_config)
+
+    
model.to(device)  # move the model to the selected device
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # Load the dataset
+    dataset = load_dataset(data_path)
+
+    def tokenize_function(examples):
+        inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
+        inputs["labels"] = inputs["input_ids"].copy()  # setting labels for a language modeling task
+        return inputs
+
+    # Tokenize the dataset and prepare for training
+    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+
+    # Data collator to dynamically pad the batched examples
+    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
+
+    # Define training arguments
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        num_train_epochs=num_epochs,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        warmup_steps=100,
+        weight_decay=0.01,
+        logging_dir="./logs",
+        logging_steps=eval_step,
+        save_steps=save_step,
+        save_total_limit=2,
+        push_to_hub=push_to_hub,
+        hub_model_id=hub_model_id,
+        gradient_accumulation_steps=16,
+        fp16=True,
+        learning_rate=learning_rate,
+        hub_token=hf_token,
+    )
+
+    # Clear CUDA cache to free memory
+    torch.cuda.empty_cache()
+
+    # Initialize the Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_datasets["train"],
+        eval_dataset=tokenized_datasets["test"],
+        data_collator=data_collator,
+    )
+
+    # Start model training
+    trainer.train()
+
+    # Save and push the trained model and tokenizer
+    if push_to_hub:
+        # Push the main model to the hub
+        trainer.push_to_hub(commit_message="Fine-tuned model")
+
+    # Save the model and tokenizer locally
+    model.save_pretrained(output_dir)
+    tokenizer.save_pretrained(output_dir)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Fine-tune LLaMA with RoAD and PEFT")
+    parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
+    parser.add_argument(
+        "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
+    )
+    parser.add_argument(
+        "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
+    )
+    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
+    parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate")
+    parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
+    parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
+    parser.add_argument("--quantize", action="store_true", help="Use quantization")
+    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
+    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
+    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training")
+    parser.add_argument("--variant", type=str, default="1", choices=["1", "2", "4"], help="RoAD variant")
+    parser.add_argument(
+        "--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAD"
+    )
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default="path/to/repo",
+        help="Repository name to push the model on the Hugging Face Hub",
+    )
+    parser.add_argument("--push_to_hub", 
action="store_true", help="Whether to push the model to Hugging Face Hub") + args = parser.parse_args() + train_model( + base_model=args.base_model, + data_path=args.data_path, + output_dir=args.output_dir, + batch_size=args.batch_size, + num_epochs=args.num_epochs, + learning_rate=args.learning_rate, + cutoff_len=args.cutoff_len, + val_set_size=args.val_set_size, + quantize=args.quantize, + eval_step=args.eval_step, + save_step=args.save_step, + device=args.device, + variant=args.variant, + road_target_modules=args.road_target_modules, + hub_model_id=args.hub_model_id, + push_to_hub=args.push_to_hub, + ) diff --git a/examples/sequence_classification/RoAD.py b/examples/sequence_classification/RoAD.py new file mode 100644 index 0000000000..3554d2c5ee --- /dev/null +++ b/examples/sequence_classification/RoAD.py @@ -0,0 +1,124 @@ +import evaluate +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup + +from peft import ( + LoraConfig, + PeftType, + RoadConfig, + get_peft_model, +) + + +use_lora = True + +batch_size = 32 +model_name_or_path = "roberta-large" +task = "mrpc" +if use_lora: + peft_type = PeftType.LORA +else: + peft_type = PeftType.ROAD +device = "cuda" +num_epochs = 20 +max_length = 128 +torch.manual_seed(0) + +if use_lora: + peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1) +else: + peft_config = RoadConfig( + task_type="SEQ_CLS", + variant="2", + target_modules=["query", "key", "value", "dense"], + ) +if use_lora: + lr = 3e-4 +else: + lr = 3e-3 + +if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")): + padding_side = "left" +else: + padding_side = "right" + +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side) +if getattr(tokenizer, "pad_token_id") is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + +datasets = load_dataset("glue", task) +metric = evaluate.load("glue", task) + + +def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=max_length) + return outputs + + +tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], +) + +# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the +# transformers library +tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + +def collate_fn(examples): + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + +# Instantiate dataloaders. 
+train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size) +eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size +) + +model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, max_length=None) +model = get_peft_model(model, peft_config) +model.print_trainable_parameters() +# print(model) + +optimizer = AdamW(params=model.parameters(), lr=lr) + +# Instantiate scheduler +lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs), + num_training_steps=(len(train_dataloader) * num_epochs), +) + +model.to(device) + +for epoch in range(num_epochs): + model.train() + for step, batch in enumerate(tqdm(train_dataloader)): + batch.to(device) + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + model.eval() + for step, batch in enumerate(tqdm(eval_dataloader)): + batch.to(device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + predictions, references = predictions, batch["labels"] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + print(f"epoch {epoch}:", eval_metric) diff --git a/src/peft/__init__.py b/src/peft/__init__.py index b2fcbe901f..a73e191569 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -91,6 +91,8 @@ PromptTuningInit, RandLoraConfig, RandLoraModel, + RoadConfig, + RoadModel, ShiraConfig, ShiraModel, TrainableTokensConfig, @@ -188,6 +190,8 @@ "PromptTuningInit", "RandLoraConfig", "RandLoraModel", + "RoadConfig", + "RoadModel", "ShiraConfig", "ShiraModel", "TaskType", diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index f758499e12..bdabbbdf0a 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -41,6 +41,7 @@ from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit from .randlora import RandLoraConfig, RandLoraModel +from .road import RoadConfig, RoadModel from .shira import ShiraConfig, ShiraModel from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel from .vblora import VBLoRAConfig, VBLoRAModel @@ -96,6 +97,8 @@ "PromptTuningInit", "RandLoraConfig", "RandLoraModel", + "RoadConfig", + "RoadModel", "ShiraConfig", "ShiraModel", "TrainableTokensConfig", diff --git a/src/peft/tuners/road/__init__.py b/src/peft/tuners/road/__init__.py new file mode 100644 index 0000000000..6eea31b440 --- /dev/null +++ b/src/peft/tuners/road/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
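+
+# Typical usage (a sketch; see examples/road_finetuning for a full script):
+#
+#     from peft import RoadConfig, get_peft_model
+#
+#     config = RoadConfig(variant="1", target_modules=["q_proj", "v_proj"])
+#     peft_model = get_peft_model(base_model, config)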
+ +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.utils import register_peft_method + +from .config import RoadConfig +from .layer import Linear, RoadLayer +from .model import RoadModel + + +__all__ = [ + "Linear", + "RoadConfig", + "RoadLayer", + "RoadModel", +] + +register_peft_method(name="road", config_cls=RoadConfig, model_cls=RoadModel, is_mixed_compatible=True) + + +def __getattr__(name): + if (name == "Linear8bitLt") and is_bnb_available(): + from .bnb import Linear8bitLt + + return Linear8bitLt + + if (name == "Linear4bit") and is_bnb_4bit_available(): + from .bnb import Linear4bit + + return Linear4bit + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/peft/tuners/road/bnb.py b/src/peft/tuners/road/bnb.py new file mode 100644 index 0000000000..0ea9420bf4 --- /dev/null +++ b/src/peft/tuners/road/bnb.py @@ -0,0 +1,407 @@ +# Copyright 2025-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import warnings +from typing import Any, Optional + +import bitsandbytes as bnb +import torch + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.integrations import dequantize_bnb_weight + +from .config import RoadVariant +from .layer import RoadLayer, _apply_road, _get_delta_weight + + +if is_bnb_available(): + + class Linear8bitLt(torch.nn.Module, RoadLayer): + # Road implemented in a dense layer + def __init__( + self, + base_layer: torch.nn.Module, + adapter_name: str, + variant: RoadVariant = RoadVariant.ROAD_1, + group_size: int = 64, + init_weights: bool = True, + **kwargs, + ) -> None: + super().__init__() + RoadLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + variant=variant, + group_size=group_size, + init_weights=init_weights, + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self._available_adapters: + warnings.warn( + "Merge road module to 8-bit linear may get different generations due to rounding errors." 
+ ) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + + # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8 + # dequantization directly + output = dequantize_bnb_weight(weight, state=state) + road_R = _get_delta_weight( + self.variant[active_adapter], + self.group_size[active_adapter], + self.road_theta[active_adapter].data, + self.road_alpha[active_adapter].data, + ) + + w_data = torch.matmul(road_R, output.to(road_R.dtype)) + w_data = w_data.to(road_R.dtype).to(road_R.device).contiguous() + + if safe_merge and not torch.isfinite(w_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + + if self.get_base_layer().bias is not None: + bias = self.get_base_layer().bias + orig_dtype = bias.dtype + bias_data = bias.data + new_bias = torch.matmul(road_R, bias_data.to(road_R.dtype)) + bias.data = new_bias.to(orig_dtype) + + state.reset_grads() + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self._available_adapters: + warnings.warn( + "Unmerge road module to 8-bit linear may get different generations due to rounding errors." + ) + + weight = self.get_base_layer().weight + state = self.get_base_layer().state + if state.SCB is None: + state.SCB = weight.SCB + output = dequantize_bnb_weight(weight, state=state) + + road_R = _get_delta_weight( + self.variant[active_adapter], + self.group_size[active_adapter], + self.road_theta[active_adapter].data, + self.road_alpha[active_adapter].data, + ) + inv_road_R = torch.linalg.inv(road_R.to(torch.float32)).to(road_R.dtype) + + w_data = torch.matmul(inv_road_R, output.to(road_R.dtype)) + w_data = w_data.to(road_R.dtype).to(road_R.device).contiguous() + + self.get_base_layer().weight = bnb.nn.Int8Params( + w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights + ).to(weight.device) + + if self.get_base_layer().bias is not None: + bias = self.get_base_layer().bias + orig_dtype = bias.dtype + bias_data = bias.data + new_bias = torch.matmul(inv_road_R, bias_data) + bias.data = new_bias.to(orig_dtype) + + state.reset_grads() + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + + for active_adapter in self.active_adapters: + if active_adapter not in self._available_adapters: + continue + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + result = self._cast_input_dtype(result, self.road_theta[active_adapter].dtype) + + result = _apply_road( + self.variant[active_adapter], + self.group_size[active_adapter], + self.road_theta[active_adapter], + self.road_alpha[active_adapter], + result, + ) + + if requires_conversion: + x = x.to(expected_dtype) + + return result + 
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "road." + rep
+
+    def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs):
+        new_module = None
+
+        if isinstance(target, BaseTunerLayer):
+            target_base_layer = target.get_base_layer()
+        else:
+            target_base_layer = target
+
+        loaded_in_8bit = kwargs.get("loaded_in_8bit", False)
+        if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
+            eightbit_kwargs = kwargs.copy()
+            eightbit_kwargs.update(
+                {
+                    "has_fp16_weights": target.state.has_fp16_weights,
+                    "threshold": target.state.threshold,
+                    "index": target.index,
+                }
+            )
+            new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
+
+        return new_module
+
+
+if is_bnb_4bit_available():
+
+    class Linear4bit(torch.nn.Module, RoadLayer):
+        # Road implemented in a dense layer
+        def __init__(
+            self,
+            base_layer: torch.nn.Module,
+            adapter_name: str,
+            variant: RoadVariant = RoadVariant.ROAD_1,
+            group_size: int = 64,
+            init_weights: bool = True,
+            **kwargs,
+        ) -> None:
+            super().__init__()
+            RoadLayer.__init__(self, base_layer)
+
+            self._active_adapter = adapter_name
+            self.update_layer(
+                adapter_name,
+                variant=variant,
+                group_size=group_size,
+                init_weights=init_weights,
+            )
+
+        def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+            """
+            Merge the active adapter weights into the base weights
+
+            Args:
+                safe_merge (`bool`, *optional*):
+                    If True, the merge operation will be performed in a copy of the original weights and check for NaNs
+                    before merging the weights. This is useful if you want to check if the merge operation will produce
+                    NaNs. Defaults to `False`.
+                adapter_names (`list[str]`, *optional*):
+                    The list of adapter names that should be merged. If None, all active adapters will be merged.
+                    Defaults to `None`.
+            """
+            adapter_names = check_adapters_to_merge(self, adapter_names)
+            if not adapter_names:
+                # no adapter to merge
+                return
+
+            for active_adapter in adapter_names:
+                if active_adapter in self._available_adapters:
+                    warnings.warn(
+                        "Merge road module to 4-bit linear may get different generations due to rounding errors."
+                    )
+                    # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
+                    weight = self.get_base_layer().weight
+                    kwargs = weight.__dict__
+
+                    output = dequantize_bnb_weight(weight, state=weight.quant_state)
+
+                    road_R = _get_delta_weight(
+                        self.variant[active_adapter],
+                        self.group_size[active_adapter],
+                        self.road_theta[active_adapter].data,
+                        self.road_alpha[active_adapter].data,
+                    )
+                    w_data = torch.matmul(road_R, output.to(road_R.dtype))
+                    w_data = w_data.to(road_R.dtype).to(road_R.device)
+
+                    if safe_merge and not torch.isfinite(w_data).all():
+                        raise ValueError(
+                            f"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken"
+                        )
+
+                    if "bnb_quantized" in kwargs:
+                        kwargs["bnb_quantized"] = False
+                    kwargs["requires_grad"] = False
+                    kwargs.pop("data", None)
+                    # torch.compile can introduce attributes preceded by '_', remove them
+                    kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+                    self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
+
+                    if self.get_base_layer().bias is not None:
+                        bias = self.get_base_layer().bias
+                        orig_dtype = bias.dtype
+                        bias_data = bias.data
+                        new_bias = torch.matmul(road_R, bias_data.to(road_R.dtype))
+                        bias.data = new_bias.to(orig_dtype)
+
+                    self.merged_adapters.append(active_adapter)
+
+        def unmerge(self) -> None:
+            """
+            This method unmerges all merged adapter layers from the base weights.
+            """
+            if not self.merged:
+                warnings.warn("Already unmerged. Nothing to do.")
+                return
+
+            while len(self.merged_adapters) > 0:
+                active_adapter = self.merged_adapters.pop()
+                if active_adapter in self._available_adapters:
+                    warnings.warn(
+                        "Unmerge road module to 4-bit linear may get different generations due to rounding errors."
+                    )
+
+                    weight = self.get_base_layer().weight
+                    kwargs = weight.__dict__
+                    output = dequantize_bnb_weight(weight, state=weight.quant_state)
+
+                    road_R = _get_delta_weight(
+                        self.variant[active_adapter],
+                        self.group_size[active_adapter],
+                        self.road_theta[active_adapter].data,
+                        self.road_alpha[active_adapter].data,
+                    )
+                    inv_road_R = torch.linalg.inv(road_R.to(torch.float32)).to(road_R.dtype)
+
+                    w_data = torch.matmul(inv_road_R, output.to(road_R.dtype))
+                    w_data = w_data.to(road_R.dtype).to(road_R.device)
+
+                    if "bnb_quantized" in kwargs:
+                        kwargs["bnb_quantized"] = False
+                    kwargs["requires_grad"] = False
+                    kwargs.pop("data", None)
+                    self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
+
+                    if self.get_base_layer().bias is not None:
+                        bias = self.get_base_layer().bias
+                        orig_dtype = bias.dtype
+                        bias_data = bias.data
+                        # cast the bias like in merge() to avoid a dtype mismatch
+                        new_bias = torch.matmul(inv_road_R, bias_data.to(road_R.dtype))
+                        bias.data = new_bias.to(orig_dtype)
+
+        def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+            if self.disable_adapters:
+                if self.merged:
+                    self.unmerge()
+                result = self.base_layer(x, *args, **kwargs)
+            elif self.merged:
+                result = self.base_layer(x, *args, **kwargs)
+            else:
+                result = self.base_layer(x, *args, **kwargs)
+                # As per Tim Dettmers, for 4bit, we need to defensively clone here.
+                # The reason is that in some cases, an error can occur that backprop
+                # does not work on a manipulated view. This issue may be solved with
+                # newer PyTorch versions but this would need extensive testing to be
+                # sure.
+                # result = result.clone()
+
+                for active_adapter in self.active_adapters:
+                    if active_adapter not in self._available_adapters:
+                        continue
+
+                    requires_conversion = not torch.is_autocast_enabled()
+                    if requires_conversion:
+                        expected_dtype = result.dtype
+                        result = self._cast_input_dtype(result, self.road_theta[active_adapter].dtype)
+
+                    result = _apply_road(
+                        self.variant[active_adapter],
+                        self.group_size[active_adapter],
+                        self.road_theta[active_adapter],
+                        self.road_alpha[active_adapter],
+                        result,
+                    )
+                    if requires_conversion:
+                        # cast the adapted output back, not the input x
+                        result = result.to(expected_dtype)
+
+            return result
+
+        def __repr__(self) -> str:
+            rep = super().__repr__()
+            return "road." 
+ rep
+
+    def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs):
+        new_module = None
+
+        if isinstance(target, BaseTunerLayer):
+            target_base_layer = target.get_base_layer()
+        else:
+            target_base_layer = target
+
+        loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
+        if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
+            fourbit_kwargs = kwargs.copy()
+            fourbit_kwargs.update(
+                {
+                    "compute_dtype": target_base_layer.compute_dtype,
+                    "compress_statistics": target_base_layer.weight.compress_statistics,
+                    "quant_type": target_base_layer.weight.quant_type,
+                }
+            )
+            new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
+
+        return new_module
diff --git a/src/peft/tuners/road/config.py b/src/peft/tuners/road/config.py
new file mode 100644
index 0000000000..90815b0d44
--- /dev/null
+++ b/src/peft/tuners/road/config.py
@@ -0,0 +1,123 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional, Union
+
+from peft.config import PeftConfig
+from peft.utils import PeftType
+
+
+class RoadVariant(str, Enum):
+    ROAD_1 = "1"
+    ROAD_2 = "2"
+    ROAD_4 = "4"
+
+
+@dataclass
+class RoadConfig(PeftConfig):
+    """
+    This is the configuration class to store the configuration of a [`RoadModel`]. The RoAD adapter is proposed in
+    https://arxiv.org/pdf/2409.00119.
+
+    Args:
+        variant (Union[`RoadVariant`, `str`]):
+            The variant of the Road model to use. It can be one of "1", "2" or "4":
+            - "1": one angle and one scale shared within each rotated pair of elements.
+            - "2": a separate angle and scale for each element of the pair.
+            - "4": like "2", but with the cosine and sine terms additionally decoupled, doubling the parameter
+              count again.
+        group_size (`int`):
+            Group size defines how elements are grouped together into 2D vectors for rotation.
+            Within each group element 0 is paired with element group_size/2,
+            then element 1 is paired with element group_size/2+1 and so on.
+            This has no effect on the model performance, since elements are unordered,
+            however it has some effect on inference speed when used in e.g. vLLM.
+            For best speed a group size of at least 64 is recommended.
+            Note that the model hidden size (or hidden size per partition when used with tensor parallelism)
+            must be divisible by group_size, so for very small models you might need to reduce this parameter.
+        init_weights (`bool`):
+            Whether to perform initialization of the Road weights.
+        target_modules (`Optional[Union[List[str], str]]`):
+            The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
+            names will be replaced. When passing a string, a regex match will be performed. When passing a list of
+            strings, either an exact match will be performed or it is checked if the name of the module ends with any
+            of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen (if
+            the model is a PreTrainedModel, the output layer excluded). If this is not specified, modules will be
+            chosen according to the model architecture. 
If the architecture is not known, an error will be raised -- in
+            this case, you should specify the target modules manually.
+        modules_to_save (`List[str]`):
+            List of modules apart from Road layers to be set as trainable and saved in the final checkpoint.
+    """
+
+    variant: Union[str, RoadVariant] = field(
+        default=RoadVariant.ROAD_1,
+        metadata={"help": ("Variant of the Road model to use. ")},
+    )
+    group_size: int = field(
+        default=64,
+        metadata={
+            "help": (
+                "Group size defines how elements are grouped together into 2D vectors for rotation. "
+                "Within each group element 0 is paired with element group_size/2, "
+                "then element 1 is paired with element group_size/2+1 and so on. "
+                "This has no effect on the model performance, since elements are unordered, "
+                "however it has some effect on inference speed when used in e.g. vLLM. "
+                "For best speed a group size of at least 64 is recommended. "
+                "Note that the model hidden size (or hidden size per partition when used with tensor parallelism) "
+                "must be divisible by group_size, so for very small models you might need to reduce this parameter."
+            )
+        },
+    )
+    init_weights: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "Whether to initialize the weights of the RoAd layers with their default initialization. Don't change "
+                "this setting, except if you know exactly what you're doing."
+            ),
+        },
+    )
+    target_modules: Optional[Union[list[str], str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "List of module names or regex expression of the module names to replace with Road. "
+                "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
+                "This can also be a wildcard 'all-linear' which matches all linear/Conv1D modules "
+                "(if the model is a PreTrainedModel, the output layer excluded). "
+                "If not specified, modules will be chosen according to the model architecture. If the architecture is "
+                "not known, an error will be raised -- in this case, you should specify the target modules manually."
+            ),
+        },
+    )
+    modules_to_save: Optional[list[str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "List of modules apart from RoAD layers to be set as trainable and saved in the final checkpoint. For"
+                " example, in Sequence Classification or Token Classification tasks, the final layer"
+                " `classifier/score` are randomly initialized and as such need to be trainable and saved."
+            )
+        },
+    )
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.peft_type = PeftType.ROAD
+        self.target_modules = (
+            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
+        )
diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py
new file mode 100644
index 0000000000..c2f342d7f9
--- /dev/null
+++ b/src/peft/tuners/road/layer.py
@@ -0,0 +1,393 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
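+
+# A rough sketch of the computation implemented in this file: a RoAD adapter
+# stores angles (road_theta) and scales (road_alpha) for rotated 2D pairs of
+# output elements (how many parameters per pair depends on the variant).
+# _prepare_cols below turns them into two vectors,
+#     first_col = alpha * cos(theta),  second_col = alpha * sin(theta),
+# and _apply_road then computes
+#     y = x * first_col + rotate_half(x) * second_col,
+# where rotate_half pairs element i with element i + group_size / 2, similar in
+# spirit to rotary position embedding implementations.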
+
+import warnings
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+
+from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+
+from .config import RoadConfig, RoadVariant
+
+
+class RoadLayer(BaseTunerLayer):
+    """
+    Road layer.
+
+    Generally the idea of RoAD is to split the input vector into many 2D vectors and rotate each 2D vector with its own
+    2D rotation matrix. For additional flexibility, each rotation matrix is multiplied by a trainable scale.
+
+    When applied to a vector as R @ x, each pair of elements of x is transformed like this:
+        y₀ = x₀ * α₀cosθ₀ - xₙ * α₀sinθ₀
+        yₙ = x₀ * α₀sinθ₀ + xₙ * α₀cosθ₀
+
+    The scales and angles inside each rotation matrix may actually be different (when using variant 2 or 4).
+
+    Note that instead of using two consecutive elements x₀ x₁ we pair elements from the first and second half of the
+    group, which allows for a more efficient inference implementation.
+
+    The adapter only needs to store the angles θ and scales α, rather than the full matrix R, and the inference
+    implementation only needs to do elementwise vector multiplications.
+
+    For merging the weights, we make use of the following formula: R @ (W @ x + b) = (R @ W) @ x + R @ b. The lhs is
+    how it is used in the unmerged state (using the efficient elementwise implementation instead of a matrix
+    multiplication) and the rhs is how it is used in the merged state, where (R @ W) becomes the new weight matrix and
+    R @ b becomes the new bias.
+    """
+
+    adapter_layer_names: tuple[str, ...] = ("road_theta", "road_alpha")
+
+    def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
+        self.base_layer = base_layer
+        self.variant = {}
+        self.group_size = {}
+        self.road_theta = nn.ParameterDict({})
+        self.road_alpha = nn.ParameterDict({})
+
+        self._disable_adapters = False
+        self.merged_adapters = []
+
+        base_layer = self.get_base_layer()
+        if isinstance(base_layer, nn.Linear):
+            in_features, out_features = base_layer.in_features, base_layer.out_features
+        else:
+            in_features, out_features = None, None
+            warnings.warn(
+                f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
+            )
+        self.in_features = in_features
+        self.out_features = out_features
+
+    @property
+    def _available_adapters(self) -> set[str]:
+        return {*self.road_theta}
+
+    def update_layer(
+        self,
+        adapter_name,
+        variant,
+        group_size,
+        init_weights,
+    ):
+        self.variant[adapter_name] = variant
+        self.group_size[adapter_name] = group_size
+
+        # Actual trainable parameters
+        if variant == RoadVariant.ROAD_1:
+            size = self.out_features // 2
+        elif variant == RoadVariant.ROAD_2:
+            size = self.out_features
+        elif variant == RoadVariant.ROAD_4:
+            size = self.out_features * 2
+        else:
+            raise ValueError(f"Unsupported variant {variant} for RoadLayer. 
Supported variants are 1, 2, and 4.")
+        self.road_theta[adapter_name] = nn.Parameter(torch.rand(size))
+        self.road_alpha[adapter_name] = nn.Parameter(torch.rand(size))
+
+        if init_weights:
+            self.reset_road_parameters(adapter_name, init_weights)
+        self._move_adapter_to_device_of_base_layer(adapter_name)
+
+        self.set_adapter(self.active_adapters)
+
+    def reset_road_parameters(self, adapter_name, init_weights):
+        if init_weights is False:
+            return
+        nn.init.zeros_(self.road_theta[adapter_name].data)
+        nn.init.ones_(self.road_alpha[adapter_name].data)
+
+
+class Linear(nn.Module, RoadLayer):
+    # Road implemented in a dense layer
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        variant: RoadVariant = RoadVariant.ROAD_1,
+        group_size: int = 64,
+        init_weights: Union[bool, str] = True,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        RoadLayer.__init__(self, base_layer, **kwargs)
+
+        self._active_adapter = adapter_name
+
+        self.update_layer(
+            adapter_name,
+            variant,
+            group_size,
+            init_weights=init_weights,
+        )
+
+    def _check_forward_args(self, x, *args, **kwargs):
+        """Check if the arguments are compatible with the configs and state of the model"""
+        adapter_names = kwargs.get("adapter_names", None)
+        if adapter_names is None:
+            return
+
+        if len(x) != len(adapter_names):
+            msg = (
+                "Length of `adapter_names` should be the same as the number of inputs, but got "
+                f"{len(adapter_names)} and {len(x)} respectively."
+            )
+            raise ValueError(msg)
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._check_forward_args(x, *args, **kwargs)
+        adapter_names = kwargs.pop("adapter_names", None)
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        elif adapter_names is not None:
+            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+        else:
+            result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
+
+            for active_adapter in self.active_adapters:
+                if active_adapter not in self._available_adapters:
+                    continue
+
+                result = self._cast_input_dtype(result, self.road_theta[active_adapter].dtype)
+                result = _apply_road(
+                    self.variant[active_adapter],
+                    self.group_size[active_adapter],
+                    self.road_theta[active_adapter],
+                    self.road_alpha[active_adapter],
+                    result,
+                )
+
+            result = result.to(torch_result_dtype)
+
+        return result
+
+    def _mixed_batch_forward(
+        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
+    ) -> torch.Tensor:
+        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
+        # extra argument that allows mixing different adapters in the same batch at inference time.
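+        # For example (adapter names here are illustrative), calling
+        #     model(x, adapter_names=["road_a", "__base__", "road_a"])
+        # runs samples 0 and 2 through adapter "road_a" and sample 1 through the
+        # base layer only ("__base__" selects no adapter).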
+ result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self._available_adapters: + continue + + dtype = self.road_theta[active_adapter].data.dtype + + # getting the sub-batch, passing it to Road layers and updating the corresponding indices of the linear + # layer output + sub_batch = result[sub_batch_indices_list[i]].to(dtype) + result[sub_batch_indices_list[i]] = _apply_road( + self.variant[active_adapter], + self.group_size[active_adapter], + self.road_theta[active_adapter], + self.road_alpha[active_adapter], + sub_batch, + ) + + return result + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If `None`, all active adapters will be merged. + Defaults to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self._available_adapters: + base_layer = self.get_base_layer() + orig_dtype = base_layer.weight.dtype + road_R = _get_delta_weight( + self.variant[active_adapter], + self.group_size[active_adapter], + self.road_theta[active_adapter].data, + self.road_alpha[active_adapter].data, + ) + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weight = base_layer.weight.data.clone() + orig_weight = torch.matmul(road_R.to(orig_dtype), orig_weight) + + if not torch.isfinite(orig_weight).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weight.contiguous().to(orig_dtype) + + if base_layer.bias is not None: + orig_bias = base_layer.bias.clone() + orig_bias = torch.matmul(road_R.to(orig_dtype), orig_bias) + base_layer.bias.data = orig_bias.contiguous().to(orig_dtype) + else: + orig_weight = base_layer.weight.data + orig_weight = torch.matmul(road_R.to(orig_dtype), orig_weight) + base_layer.weight.data = orig_weight.contiguous().to(orig_dtype) + + if base_layer.bias is not None: + orig_bias = base_layer.bias.data + orig_bias = torch.matmul(road_R.to(orig_dtype), orig_bias) + base_layer.bias.data = orig_bias.contiguous().to(orig_dtype) + + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + """ + This method unmerges all merged adapter layers from the base weights. + """ + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + while len(self.merged_adapters) > 0: + # Going in reverse order + active_adapter = self.merged_adapters.pop() + if active_adapter in self._available_adapters: + weight = self.get_base_layer().weight + orig_dtype = weight.dtype + road_R = _get_delta_weight( + self.variant[active_adapter], + self.group_size[active_adapter], + self.road_theta[active_adapter].data, + self.road_alpha[active_adapter].data, + ) + # Since our matrix are not necessarily orthogonal we need inverse instead of transpose + inv_road_R = torch.linalg.inv(road_R.to(torch.float32)).to(orig_dtype) + orig_weight = torch.matmul(inv_road_R, weight.data) + weight.data = orig_weight.contiguous() + + if self.get_base_layer().bias is not None: + orig_bias = torch.matmul(inv_road_R, self.get_base_layer().bias.data) + self.get_base_layer().bias.data = orig_bias.contiguous() + + def __repr__(self) -> str: + rep = super().__repr__() + return "road." + rep + + +def _get_delta_weight(variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor): + first_col, second_col = _prepare_cols(variant, group_size, road_theta, road_alpha) + + # First column is simply put on the main diagonal + output_tensor = torch.diag(first_col) + # For second column we need to swap each half groups and add minus sign + size = second_col.shape[0] + swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten() + rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :] + rotated_diag_second_col[:, 0, :, :] *= -1 + rotated_diag_second_col = rotated_diag_second_col.reshape(size, size) + output_tensor += rotated_diag_second_col + + return output_tensor + + +def _prepare_cols( + variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # In inference mode, this can be cached + if variant == RoadVariant.ROAD_1: + # In each group there are only group_size // 2 parameters that are reused + road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten() + road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten() + + theta_cos = road_theta.cos() + theta_sin = road_theta.sin() + + first_col = road_alpha * theta_cos + second_col = road_alpha * theta_sin + elif variant == RoadVariant.ROAD_2: + # Each group has exactly group_size parameters + theta_cos = road_theta.cos() + theta_sin = road_theta.sin() + + first_col = road_alpha * theta_cos + second_col = road_alpha * theta_sin + elif variant == RoadVariant.ROAD_4: + # Each group has 2*group_size parameters, first half used for first column, second half for second column + road_theta = road_theta.reshape(-1, 2, group_size) + theta_cos = road_theta[:, 0, :].cos().flatten() + theta_sin = road_theta[:, 1, :].sin().flatten() + road_alpha = road_alpha.reshape(-1, 2, group_size) + alpha_1 = road_alpha[:, 0, :].flatten() + alpha_2 = road_alpha[:, 1, :].flatten() + + first_col = alpha_1 * theta_cos + second_col = alpha_2 * theta_sin + + return first_col, second_col + + +def _apply_road( + variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor, x: torch.Tensor +): + first_col, second_col = _prepare_cols(variant, group_size, road_theta, road_alpha) + + # Split in half groups and join back + x_grouped = x.reshape(-1, 2, group_size // 2) + x1 = x_grouped[:, 0, :] + x2 = x_grouped[:, 1, :] + rotate_half_x = torch.stack((-x2, x1), 
dim=1).reshape(x.shape)
+    result = x * first_col + rotate_half_x * second_col
+    return result
+
+
+def dispatch_default(
+    target: torch.nn.Module,
+    adapter_name: str,
+    road_config: RoadConfig,
+    **kwargs,
+) -> Optional[torch.nn.Module]:
+    new_module = None
+
+    if isinstance(target, BaseTunerLayer):
+        target_base_layer = target.get_base_layer()
+    else:
+        target_base_layer = target
+
+    if isinstance(target_base_layer, torch.nn.Linear):
+        new_module = Linear(target, adapter_name, **kwargs)
+
+    return new_module
diff --git a/src/peft/tuners/road/model.py b/src/peft/tuners/road/model.py
new file mode 100644
index 0000000000..1c19796665
--- /dev/null
+++ b/src/peft/tuners/road/model.py
@@ -0,0 +1,302 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import operator
+from typing import Optional
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.tuners.road.config import RoadConfig
+from peft.tuners.tuners_utils import (
+    BaseTuner,
+    BaseTunerLayer,
+    check_target_module_exists,
+    onload_layer,
+)
+from peft.utils import (
+    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+    ModulesToSaveWrapper,
+    _get_submodules,
+)
+
+from .layer import RoadLayer, dispatch_default
+
+
+class RoadModel(BaseTuner):
+    """
+    Creates a RoAD model from a pretrained transformers model. RoAD is described in
+    https://arxiv.org/pdf/2409.00119.
+    """
+
+    prefix: str = "road_"
+
+    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
+        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
+
+    @staticmethod
+    def _prepare_adapter_config(road_config: RoadConfig, model_config: dict) -> RoadConfig:
+        if road_config.target_modules is None:
+            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
+                raise ValueError("Please specify `target_modules` in `peft_config`")
+            road_config.target_modules = set(
+                TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            )
+        return road_config
+
+    @staticmethod
+    def _check_target_module_exists(road_config, key):
+        return check_target_module_exists(road_config, key)
+
+    def _create_and_replace(
+        self,
+        road_config: RoadConfig,
+        adapter_name: str,
+        target: nn.Module,
+        target_name: str,
+        parent: nn.Module,
+        current_key,
+    ) -> None:
+        if current_key is None:
+            raise ValueError("Current Key shouldn't be `None`")
+
+        variant = road_config.variant
+        group_size = road_config.group_size
+
+        kwargs = {
+            "variant": variant,
+            "group_size": group_size,
+            "init_weights": road_config.init_weights,
+            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
+            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
+        }
+        # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
+        try:
+            kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
"hf_quantizer.quantization_config.get_apply_tensor_subclass" + )(self.model) + except AttributeError: + pass + + if isinstance(target, RoadLayer): + target.update_layer( + adapter_name, + variant, + group_size, + init_weights=road_config.init_weights, + ) + else: + device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None + new_module = self._create_new_module(road_config, adapter_name, target, device_map=device_map, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + meta = torch.device("meta") + # dispatch to correct device + for name, module in new_module.named_modules(): + if (self.prefix in name) or ("ranknum" in name): + if hasattr(child, "qweight"): + weight = child.qweight + elif hasattr(child, "W_q"): + weight = child.W_q + elif hasattr(child, "weight"): + weight = child.weight + elif getattr(child, "in_proj_weight", None) is not None: # MHA + weight = child.in_proj_weight + else: + weight = next(child.parameters()) + if not any(p.device == meta for p in module.parameters()): + module.to(weight.device) + + @staticmethod + def _create_new_module(road_config: RoadConfig, adapter_name, target, **kwargs): + dispatchers = [] + + # avoid eager bnb import + if is_bnb_available(): + from .bnb import dispatch_bnb_8bit + + dispatchers.append(dispatch_bnb_8bit) + + if is_bnb_4bit_available(): + from .bnb import dispatch_bnb_4bit + + dispatchers.append(dispatch_bnb_4bit) + + dispatchers.extend( + [ + dispatch_default, + ] + ) + + new_module = None + for dispatcher in dispatchers: + new_module = dispatcher(target, adapter_name, road_config=road_config, **kwargs) + if new_module is not None: # first match wins + break + + if new_module is None: + # no module could be matched + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear` " + ) + + return new_module + + def _mark_only_adapters_as_trainable(self, model: nn.Module): + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + def _set_adapter_layers(self, enabled: bool = True) -> None: + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def disable_adapter_layers(self) -> None: + self._set_adapter_layers(enabled=False) + + def enable_adapter_layers(self) -> None: + self._set_adapter_layers(enabled=True) + + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. 
+        """
+        for module in self.model.modules():
+            if isinstance(module, RoadLayer):
+                module.set_adapter(adapter_name)
+        self.active_adapter = adapter_name
+
+    def __getattr__(self, name: str):
+        """Forward missing attributes to the wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
+                raise
+            return getattr(self.model, name)
+
+    def _unload_and_optionally_merge(
+        self,
+        merge=True,
+        progressbar: bool = False,
+        safe_merge: bool = False,
+        adapter_names: Optional[list[str]] = None,
+    ):
+        if merge:
+            self._check_merge_allowed()
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        desc = "Unloading " + ("and merging " if merge else "") + "model"
+        for key in tqdm(key_list, disable=not progressbar, desc=desc):
+            try:
+                parent, target, target_name = _get_submodules(self.model, key)
+            except AttributeError:
+                continue
+            with onload_layer(target):
+                if hasattr(target, "base_layer"):
+                    if merge:
+                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                    self._replace_module(parent, target_name, target.get_base_layer(), target)
+                elif isinstance(target, ModulesToSaveWrapper):
+                    # save any additional trainable modules part of `modules_to_save`
+                    new_module = target.modules_to_save[target.active_adapter]
+                    if hasattr(new_module, "base_layer"):
+                        # check if the module is itself a tuner layer
+                        if merge:
+                            new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                        new_module = new_module.get_base_layer()
+                    setattr(parent, target_name, new_module)
+
+        return self.model
+
+    def delete_adapter(self, adapter_name: str) -> None:
+        """
+        Deletes an existing adapter.
+
+        Args:
+            adapter_name (str): Name of the adapter to be deleted.
+        """
+        if adapter_name not in list(self.peft_config.keys()):
+            raise ValueError(f"Adapter {adapter_name} does not exist")
+        del self.peft_config[adapter_name]
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        new_adapter = None
+        for key in key_list:
+            _, target, _ = _get_submodules(self.model, key)
+            if isinstance(target, RoadLayer):
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        self.active_adapter = new_adapter or []
+        self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)
+
+    def merge_and_unload(
+        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
+    ) -> torch.nn.Module:
+        r"""
+        This method merges the RoAd layers into the base model. This is needed if someone wants to use the base model
+        as a standalone model.
+
+        Args:
+            progressbar (`bool`):
+                whether to show a progressbar indicating the unload and merge process
+            safe_merge (`bool`):
+                whether to activate the safe merging check to check if there is any potential Nan in the adapter
+                weights
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+                to `None`.
+
+        """
+        return self._unload_and_optionally_merge(
+            progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+        )
+
+    def unload(self) -> torch.nn.Module:
+        """
+        Gets back the base model by removing all the road modules without merging. 
This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index 6e4aeae248..2c78ae9f50 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -43,6 +43,7 @@ class PeftType(str, enum.Enum): - RANDLORA - SHIRA - C3A + - ROAD """ PROMPT_TUNING = "PROMPT_TUNING" @@ -67,6 +68,7 @@ class PeftType(str, enum.Enum): CPT = "CPT" BONE = "BONE" RANDLORA = "RANDLORA" + ROAD = "ROAD" TRAINABLE_TOKENS = "TRAINABLE_TOKENS" SHIRA = "SHIRA" C3A = "C3A" diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index a7c810aa8a..6eba926ebf 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -48,6 +48,7 @@ OFTConfig, PeftModel, RandLoraConfig, + RoadConfig, TaskType, VBLoRAConfig, VeraConfig, @@ -74,12 +75,14 @@ from peft.tuners.ia3 import Linear8bitLt as IA3Linear8bitLt from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt from peft.tuners.randlora import Linear8bitLt as RandLoraLinear8bitLt + from peft.tuners.road import Linear8bitLt as RoadLinear8bitLt from peft.tuners.vera import Linear8bitLt as VeraLinear8bitLt if is_bnb_4bit_available(): from peft.tuners.ia3 import Linear4bit as IA3Linear4bit from peft.tuners.lora import Linear4bit as LoraLinear4bit from peft.tuners.randlora import Linear4bit as RandLoraLinear4bit + from peft.tuners.road import Linear4bit as RoadLinear4bit from peft.tuners.vera import Linear4bit as VeraLinear4bit @@ -292,6 +295,49 @@ def test_ia3_bnb_8bit_quantization(self): whisper_8bit = get_peft_model(whisper_8bit, config) assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear8bitLt) + @require_bitsandbytes + @pytest.mark.multi_gpu_tests + @pytest.mark.single_gpu_tests + def test_road_bnb_8bit_quantization(self): + r""" + Test that tests if the 8bit quantization using Road works as expected + """ + whisper_8bit = WhisperForConditionalGeneration.from_pretrained( + self.audio_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + opt_8bit = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + flan_8bit = AutoModelForSeq2SeqLM.from_pretrained( + self.seq2seq_model_id, + device_map="auto", + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + flan_road_config = RoadConfig(target_modules=["q", "v"], task_type="SEQ_2_SEQ_LM") + + opt_road_config = RoadConfig( + target_modules=["q_proj", "v_proj", "fc2"], + task_type="CAUSAL_LM", + ) + + config = RoadConfig(target_modules=["q_proj", "v_proj", "fc2"]) + + flan_8bit = get_peft_model(flan_8bit, flan_road_config) + assert isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RoadLinear8bitLt) + + opt_8bit = get_peft_model(opt_8bit, opt_road_config) + assert isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear8bitLt) + + whisper_8bit = get_peft_model(whisper_8bit, config) + assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear8bitLt) + @require_bitsandbytes @pytest.mark.multi_gpu_tests @pytest.mark.single_gpu_tests @@ -697,6 +743,49 @@ def test_ia3_bnb_4bit_quantization(self): whisper_4bit = get_peft_model(whisper_4bit, config) assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear4bit) + @require_bitsandbytes + 
@pytest.mark.multi_gpu_tests
+    @pytest.mark.single_gpu_tests
+    def test_road_bnb_4bit_quantization(self):
+        r"""
+        Test that tests if the 4bit quantization using RoAd works as expected
+        """
+        whisper_4bit = WhisperForConditionalGeneration.from_pretrained(
+            self.audio_model_id,
+            device_map="auto",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+        )
+
+        opt_4bit = AutoModelForCausalLM.from_pretrained(
+            self.causal_lm_model_id,
+            device_map="auto",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+        )
+
+        flan_4bit = AutoModelForSeq2SeqLM.from_pretrained(
+            self.seq2seq_model_id,
+            device_map="auto",
+            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+        )
+
+        flan_road_config = RoadConfig(target_modules=["q", "v"], task_type="SEQ_2_SEQ_LM")
+
+        opt_road_config = RoadConfig(
+            target_modules=["q_proj", "v_proj", "fc2"],
+            task_type="CAUSAL_LM",
+        )
+
+        config = RoadConfig(target_modules=["q_proj", "v_proj", "fc2"])
+
+        flan_4bit = get_peft_model(flan_4bit, flan_road_config)
+        assert isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RoadLinear4bit)
+
+        opt_4bit = get_peft_model(opt_4bit, opt_road_config)
+        assert isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear4bit)
+
+        whisper_4bit = get_peft_model(whisper_4bit, config)
+        assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear4bit)
+
     @pytest.mark.multi_gpu_tests
     @require_torch_multi_accelerator
     def test_lora_causal_lm_multi_gpu_inference(self):
@@ -1520,6 +1609,98 @@ def test_dora_ephemeral_gpu_offload_multigpu(self):
             layer.lora_A, layer.lora_B = la, lb
             layer.lora_variant[adapter_name].init(layer, adapter_name=adapter_name)  # should not raise an error
 
+    @require_non_cpu
+    @pytest.mark.single_gpu_tests
+    @require_bitsandbytes
+    def test_8bit_road_merging(self):
+        # Check results for merging, unmerging, unloading
+        torch.manual_seed(0)
+        model = AutoModelForCausalLM.from_pretrained(
+            "facebook/opt-125m",
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
+            torch_dtype=torch.float32,
+        ).eval()
+
+        random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
+        # compare outputs in probability space, because logits can have outliers
+        # and token ids are not precise enough
+        out_base = F.softmax(model(random_input).logits, dim=-1)
+
+        config = RoadConfig(
+            init_weights=False,
+        )
+        model = get_peft_model(model, config).eval()
+
+        with torch.inference_mode():
+            out_road = F.softmax(model(random_input).logits, dim=-1)
+
+            model.merge_adapter()
+            out_merged = F.softmax(model(random_input).logits, dim=-1)
+
+            model.unmerge_adapter()
+            out_unmerged = F.softmax(model(random_input).logits, dim=-1)
+
+            model = model.merge_and_unload()
+            out_unloaded = F.softmax(model(random_input).logits, dim=-1)
+
+        atol = 1e-3
+        rtol = 1
+        # sanity check that using RoAd changes the results
+        assert not torch.allclose(out_base, out_road, atol=atol, rtol=rtol)
+        assert torch.allclose(out_road, out_merged, atol=atol, rtol=rtol)
+        assert torch.allclose(out_road, out_unmerged, atol=atol, rtol=rtol)
+        assert torch.allclose(out_road, out_unloaded, atol=atol, rtol=rtol)
+
+    @require_non_cpu
+    @pytest.mark.single_gpu_tests
+    @require_bitsandbytes
+    def test_4bit_road_merging(self):
+        # Check results for merging, unmerging, unloading
+        torch.manual_seed(0)
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=False,
+            bnb_4bit_compute_dtype=torch.float32,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            "trl-internal-testing/tiny-random-LlamaForCausalLM",
+            quantization_config=bnb_config,
+            torch_dtype=torch.float32,
+        ).eval()
+        random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
+        # compare outputs in probability space, because logits can have outliers
+        # and token ids are not precise enough
+        out_base = model(random_input).logits
+        probs_base = F.softmax(out_base, dim=-1)
+
+        config = RoadConfig(
+            init_weights=False,
+            group_size=4,
+        )
+        model = get_peft_model(model, config).eval()
+
+        with torch.inference_mode():
+            out_road = model(random_input).logits
+            probs_road = F.softmax(out_road, dim=-1)
+
+            model.merge_adapter()
+            probs_merged = F.softmax(model(random_input).logits, dim=-1)
+
+            model.unmerge_adapter()
+            probs_unmerged = F.softmax(model(random_input).logits, dim=-1)
+
+            model = model.merge_and_unload()
+            probs_unloaded = F.softmax(model(random_input).logits, dim=-1)
+
+        atol = 1e-5
+        rtol = 1e-3
+        # sanity check that using RoAd changes the results
+        # here we compare the raw logits instead of probabilities, as the latter may not be sensitive enough
+        assert not torch.allclose(out_base, out_road, atol=atol, rtol=rtol)
+        assert torch.allclose(probs_road, probs_merged, atol=atol, rtol=rtol)
+        assert torch.allclose(probs_road, probs_unmerged, atol=atol, rtol=rtol)
+        assert torch.allclose(probs_road, probs_unloaded, atol=atol, rtol=rtol)
+
     def test_apply_GS_hra_inference(self):
         # check for different result with and without apply_GS
         model = AutoModelForCausalLM.from_pretrained(
@@ -1984,3 +2165,21 @@ def test_hra_add_new_adapter_does_not_change_device(self, mlp):
         # the rest should be on GPU
         assert model.lin0.base_layer.weight.device.type == self.device
         assert model.lin0.hra_u.other.device.type == self.device
+
+    def test_road_add_new_adapter_does_not_change_device(self, mlp):
+        # same as the HRA test above, but using RoAd
+        config = RoadConfig(target_modules=["lin0"])
+        model = get_peft_model(mlp, config)
+        model = model.to(self.device)
+        model.lin0.road_theta.cpu()
+
+        # check that the adapter is indeed on CPU and the base model on GPU
+        assert model.lin0.road_theta.default.device.type == "cpu"
+        assert model.lin0.base_layer.weight.device.type == self.device
+
+        model.add_adapter("other", config)
+        # check that after adding a new adapter, the old adapter is still on CPU
+        assert model.lin0.road_theta.default.device.type == "cpu"
+        # the rest should be on GPU
+        assert model.lin0.base_layer.weight.device.type == self.device
+        assert model.lin0.road_theta.other.device.type == self.device
diff --git a/tests/test_config.py b/tests/test_config.py
index 179496b6f3..eddeb46244 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -40,6 +40,7 @@
     PromptEncoder,
     PromptEncoderConfig,
     PromptTuningConfig,
+    RoadConfig,
     TaskType,
     VBLoRAConfig,
     VeraConfig,
@@ -65,6 +66,7 @@
     (PrefixTuningConfig, {}),
     (PromptEncoderConfig, {}),
     (PromptTuningConfig, {}),
+    (RoadConfig, {}),
     (VeraConfig, {}),
     (VBLoRAConfig, {}),
 )
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index e88002f787..d93654d8ae 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -46,6 +46,7 @@
     OFTConfig,
     PeftModel,
     RandLoraConfig,
+    RoadConfig,
     ShiraConfig,
     TaskType,
     TrainableTokensConfig,
@@ -683,6 +684,15 @@
             "modules_to_save": ["lin1"],
         },
     ),
+    ########
+    # RoAd #
+    ########
+    ("Vanilla MLP 1 RoAd", "MLP", RoadConfig, {"target_modules": "lin0", "group_size": 2}),
+    ("Vanilla MLP 2 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "group_size":
2}), + ("Vanilla MLP 3 RoAd", "MLP", RoadConfig, {"target_modules": ["lin1"], "group_size": 2}), + ("Vanilla MLP 4 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0", "lin1"], "group_size": 2}), + ("Vanilla MLP 5 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "2", "group_size": 2}), + ("Vanilla MLP 6 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "4", "group_size": 2}), ] # For this test matrix, each tuple consists of: @@ -881,6 +891,34 @@ {"target_modules": ["lin0"], "init_weights": False, "boft_block_size": 2}, {"target_modules": ["lin1"], "init_weights": False, "boft_block_size": 2}, ), + ( + "RoAd Same", + "road", + RoadConfig, + {"target_modules": ["lin0"], "init_weights": False, "group_size": 2}, + {"target_modules": ["lin0"], "init_weights": False, "group_size": 2}, + ), + ( + "RoAd Different", + "road", + RoadConfig, + {"target_modules": ["lin0"], "init_weights": False, "group_size": 2}, + {"target_modules": ["lin1"], "init_weights": False, "group_size": 2}, + ), + ( + "RoAd 2 Different", + "road", + RoadConfig, + {"target_modules": ["lin0"], "init_weights": False, "variant": "1", "group_size": 2}, + {"target_modules": ["lin1"], "init_weights": False, "variant": "2", "group_size": 2}, + ), + ( + "RoAd 4 Different", + "road", + RoadConfig, + {"target_modules": ["lin0"], "init_weights": False, "variant": "1", "group_size": 2}, + {"target_modules": ["lin1"], "init_weights": False, "variant": "4", "group_size": 2}, + ), ] PREFIXES = { @@ -899,6 +937,7 @@ ShiraConfig: "shira_", VBLoRAConfig: "vblora_", BoneConfig: "bone_", + RoadConfig: "road_", TrainableTokensConfig: "trainable_tokens_", } diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 3e756c2f43..605625fcb8 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -40,6 +40,7 @@ PromptEncoderConfig, PromptTuningConfig, PromptTuningInit, + RoadConfig, ShiraConfig, VBLoRAConfig, VeraConfig, @@ -181,6 +182,14 @@ "num_virtual_tokens": 10, }, ), + ( + RoadConfig, + { + "task_type": "CAUSAL_LM", + "variant": "1", + "group_size": 2, + }, + ), ( ShiraConfig, { @@ -230,10 +239,11 @@ def _skip_if_not_conv1d_supported(model_id, config_cls): BoneConfig, HRAConfig, OFTConfig, + RoadConfig, ShiraConfig, C3AConfig, ]: - pytest.skip("Skipping BOFT/HRA/OFT/Bone/SHiRA/C3A for GPT2LMHeadModel") + pytest.skip("Skipping BOFT/HRA/OFT/Bone/Road/SHiRA/C3A for GPT2LMHeadModel") def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls): @@ -244,6 +254,7 @@ def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls): OFTConfig, BoneConfig, C3AConfig, + RoadConfig, ]: pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone for GPT2LMHeadModel") diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index 3fca67683d..5f74d7e6a8 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -30,6 +30,7 @@ PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, + RoadConfig, ShiraConfig, TaskType, VBLoRAConfig, @@ -146,6 +147,14 @@ "task_type": "SEQ_2_SEQ_LM", }, ), + ( + RoadConfig, + { + "task_type": "SEQ_2_SEQ_LM", + "variant": "1", + "group_size": 2, + }, + ), ( ShiraConfig, { diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index 1a9659355c..787ad75b4a 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -29,6 +29,7 @@ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, + RoadConfig, 
ShiraConfig, VBLoRAConfig, VeraConfig, @@ -146,6 +147,14 @@ "num_virtual_tokens": 10, }, ), + ( + RoadConfig, + { + "task_type": "FEATURE_EXTRACTION", + "variant": "1", + "group_size": 2, + }, + ), ( ShiraConfig, { diff --git a/tests/test_seq_classifier.py b/tests/test_seq_classifier.py index 23d9067b86..0a02dc4bae 100644 --- a/tests/test_seq_classifier.py +++ b/tests/test_seq_classifier.py @@ -29,6 +29,7 @@ PromptEncoderConfig, PromptTuningConfig, PromptTuningInit, + RoadConfig, ShiraConfig, VBLoRAConfig, VeraConfig, @@ -146,6 +147,14 @@ "num_virtual_tokens": 10, }, ), + ( + RoadConfig, + { + "task_type": "SEQ_CLS", + "variant": "1", + "group_size": 2, + }, + ), ( ShiraConfig, { diff --git a/tests/testing_common.py b/tests/testing_common.py index 84fbdee54a..a63aebdce4 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -1614,6 +1614,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): "SHIRA", "BONE", "C3A", + "ROAD", ): with pytest.raises(AttributeError): model = model.unload() From 2df411fca8d37a04272fae073b91f81e2cc921bb Mon Sep 17 00:00:00 2001 From: Pavel Petrushkov Date: Mon, 28 Jul 2025 23:40:01 +0200 Subject: [PATCH 02/10] Cleanup and documentation --- docs/source/package_reference/road.md | 29 ++++ examples/road_finetuning/README.md | 90 +++++++++++++ examples/road_finetuning/road_finetuning.py | 12 +- examples/sequence_classification/RoAD.py | 124 ------------------ .../llama-3.2-3B-variant2/adapter_config.json | 12 ++ .../training_params.json | 5 + src/peft/tuners/road/__init__.py | 2 +- src/peft/tuners/road/config.py | 4 +- src/peft/tuners/road/layer.py | 19 ++- src/peft/tuners/road/model.py | 7 +- 10 files changed, 159 insertions(+), 145 deletions(-) create mode 100644 docs/source/package_reference/road.md create mode 100644 examples/road_finetuning/README.md delete mode 100644 examples/sequence_classification/RoAD.py create mode 100644 method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json create mode 100644 method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/training_params.json diff --git a/docs/source/package_reference/road.md b/docs/source/package_reference/road.md new file mode 100644 index 0000000000..e929028b29 --- /dev/null +++ b/docs/source/package_reference/road.md @@ -0,0 +1,29 @@ + + +# RoAd + +[RoAd](https://arxiv.org/pdf/2409.00119) is a parameter‑efficient fine‑tuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions, achieving competitive or superior performance with under 0.1% trainable parameters. Unlike LoRA’s batched low‑rank updates, RoAd’s sparse rotations reformulate to simple element‑wise operations, yielding significantly higher serving throughput when handling heterogeneous requests in the same batch. Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, enabling interpretable, composable task‑specific adaptations by combining orthogonal subspaces learned for different tasks. + +Finetuning with RoAd typically requires higher learning rate compared to LoRA or similar methods, around 1e-3. 
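+
+To make the element-wise reformulation concrete: for each pair of hidden dimensions (x₀, xₙ), RoAd computes y₀ = α·(x₀·cosθ − xₙ·sinθ) and yₙ = α·(x₀·sinθ + xₙ·cosθ). Below is a minimal PyTorch sketch of the RoAd-1 case (illustrative only; the toy sizes and tensor names are made up, the actual implementation lives in `peft.tuners.road.layer`):
+
+```python
+import torch
+
+hidden, group_size = 8, 4               # assumed toy sizes, not PEFT defaults
+x = torch.randn(hidden)
+theta = torch.zeros(hidden // 2)        # one angle per 2D pair (RoAd-1)
+alpha = torch.ones(hidden // 2)         # one scale per 2D pair (RoAd-1)
+
+# pair the first half of each group with the second half
+x1, x2 = x.view(-1, 2, group_size // 2).unbind(dim=1)
+t = theta.view(-1, group_size // 2)
+a = alpha.view(-1, group_size // 2)
+y1 = a * (x1 * t.cos() - x2 * t.sin())  # first element of each pair
+y2 = a * (x1 * t.sin() + x2 * t.cos())  # second element of each pair
+y = torch.stack((y1, y2), dim=1).reshape(-1)
+assert torch.allclose(y, x)             # theta=0, alpha=1 is the identity, i.e. the default init
+```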
+
+## RoadConfig
+
+[[autodoc]] tuners.road.config.RoadConfig
+
+## RoadModel
+
+[[autodoc]] tuners.road.model.RoadModel
diff --git a/examples/road_finetuning/README.md b/examples/road_finetuning/README.md
new file mode 100644
index 0000000000..e7e8b0d9df
--- /dev/null
+++ b/examples/road_finetuning/README.md
@@ -0,0 +1,90 @@
+# RoAd: 3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability
+
+
+## Introduction
+[RoAd](https://arxiv.org/pdf/2409.00119) a novel method
+which employs a straightforward 2D rotation to adapt LLMs which is
+remarkably parameter-efficient, delivering good
+performance with < 0.1% trainable parameters; efficient
+in serving requests requiring different adapters within a batch, with an overhead
+comparable to element-wise multiplication instead of batch matrix multiplication;
+enhances LLM’s interpretability.
+
+## Quick start
+```python
+import torch
+from peft import RoadConfig, get_peft_model
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from trl import SFTConfig, SFTTrainer
+from datasets import load_dataset
+
+model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda")
+tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
+road_config = RoadConfig(
+    variant="road_1",
+)
+peft_model = get_peft_model(model, road_config)
+training_args = SFTConfig(output_dir="road-llama-3-8b", dataset_text_field="text", max_seq_length=2048)
+trainer = SFTTrainer(
+    model=peft_model,
+    args=training_args,
+    train_dataset=dataset,
+    processing_class=tokenizer,
+)
+trainer.train()
+peft_model.save_pretrained("road-llama-3-8b")
+```
+
+RoAd requires a higher learning rate compared to LoRA and similar approaches; set it to around 1e-3.
+
+Run the finetuning script simply by running:
+
+```bash
+python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
+```
+
+RoAd also supports quantization. To use 4-bit quantization try:
+
+```bash
+python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --quantize
+```
+
+### Full example of the script
+```bash
+python road_finetuning.py \
+    --base_model "PATH_TO_MODEL" \
+    --data_path "PATH_TO_DATASET" \
+    --output_dir "PATH_TO_OUTPUT_DIR" \
+    --batch_size 1 \
+    --num_epochs 3 \
+    --learning_rate 1e-3 \
+    --cutoff_len 512 \
+    --val_set_size 500 \
+    --quantize \
+    --eval_step 10 \
+    --save_step 100 \
+    --device "cuda:0" \
+    --variant 1 \
+    --road_target_modules "q_proj,k_proj,v_proj,o_proj" \
+    --hub_model_id "YOUR_HF_REPO" \
+    --push_to_hub
+```
+## Use the model on 🤗
+You can load and use the model as any other 🤗 model.
+```python +from transformers import AutoModel +model = AutoModel.from_pretrained("ppetrushkov/llama-2-7b-sql-road-test") +``` + + +## Citation +``` +@inproceedings{ + liao2024in, + title={3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability}, + author={Baohao Liao and Christof Monz}, + booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, + year={2024}, + url={https://openreview.net/forum?id=rYjYwuM6yH} +} +``` diff --git a/examples/road_finetuning/road_finetuning.py b/examples/road_finetuning/road_finetuning.py index ee8005dae1..30e41a6a53 100644 --- a/examples/road_finetuning/road_finetuning.py +++ b/examples/road_finetuning/road_finetuning.py @@ -42,7 +42,7 @@ def train_model( # load tokenizer tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token) - # QDoRA (quantized dora): IF YOU WANNA QUANTIZE THE MODEL + # IF YOU WANNA QUANTIZE THE MODEL if quantize: model = AutoModelForCausalLM.from_pretrained( base_model, @@ -60,8 +60,8 @@ def train_model( model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) else: model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token, device_map="auto") - # LoRa config for the PEFT model - lora_config = RoadConfig( + # RoAd config for the PEFT model + road_config = RoadConfig( variant=variant, # Rank of matrix target_modules=( road_target_modules.split(",") @@ -70,8 +70,8 @@ def train_model( ), ) - # get the peft model with LoRa config - model = get_peft_model(model, lora_config) + # get the peft model with RoAd config + model = get_peft_model(model, road_config) model.to(device) # MODEL TO GPU/CUDA tokenizer.pad_token = tokenizer.eos_token @@ -157,7 +157,7 @@ def tokenize_function(examples): parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training") parser.add_argument("--variant", type=str, default="1", choices=["1", "2", "4"], help="RoAD variant") parser.add_argument( - "--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA" + "--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAd" ) parser.add_argument( "--hub_model_id", diff --git a/examples/sequence_classification/RoAD.py b/examples/sequence_classification/RoAD.py deleted file mode 100644 index 3554d2c5ee..0000000000 --- a/examples/sequence_classification/RoAD.py +++ /dev/null @@ -1,124 +0,0 @@ -import evaluate -import torch -from datasets import load_dataset -from torch.optim import AdamW -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup - -from peft import ( - LoraConfig, - PeftType, - RoadConfig, - get_peft_model, -) - - -use_lora = True - -batch_size = 32 -model_name_or_path = "roberta-large" -task = "mrpc" -if use_lora: - peft_type = PeftType.LORA -else: - peft_type = PeftType.ROAD -device = "cuda" -num_epochs = 20 -max_length = 128 -torch.manual_seed(0) - -if use_lora: - peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1) -else: - peft_config = RoadConfig( - task_type="SEQ_CLS", - variant="2", - target_modules=["query", "key", "value", "dense"], - ) -if use_lora: - lr = 3e-4 -else: - lr = 3e-3 - -if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")): - padding_side = "left" -else: - padding_side = "right" - -tokenizer = 
AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side) -if getattr(tokenizer, "pad_token_id") is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - -datasets = load_dataset("glue", task) -metric = evaluate.load("glue", task) - - -def tokenize_function(examples): - # max_length=None => use the model max length (it's actually the default) - outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=max_length) - return outputs - - -tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - remove_columns=["idx", "sentence1", "sentence2"], -) - -# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the -# transformers library -tokenized_datasets = tokenized_datasets.rename_column("label", "labels") - - -def collate_fn(examples): - return tokenizer.pad(examples, padding="longest", return_tensors="pt") - - -# Instantiate dataloaders. -train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size) -eval_dataloader = DataLoader( - tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size -) - -model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, max_length=None) -model = get_peft_model(model, peft_config) -model.print_trainable_parameters() -# print(model) - -optimizer = AdamW(params=model.parameters(), lr=lr) - -# Instantiate scheduler -lr_scheduler = get_linear_schedule_with_warmup( - optimizer=optimizer, - num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs), - num_training_steps=(len(train_dataloader) * num_epochs), -) - -model.to(device) - -for epoch in range(num_epochs): - model.train() - for step, batch in enumerate(tqdm(train_dataloader)): - batch.to(device) - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - model.eval() - for step, batch in enumerate(tqdm(eval_dataloader)): - batch.to(device) - with torch.no_grad(): - outputs = model(**batch) - predictions = outputs.logits.argmax(dim=-1) - predictions, references = predictions, batch["labels"] - metric.add_batch( - predictions=predictions, - references=references, - ) - - eval_metric = metric.compute() - print(f"epoch {epoch}:", eval_metric) diff --git a/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json new file mode 100644 index 0000000000..3674b02ffe --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json @@ -0,0 +1,12 @@ +{ + "auto_mapping": null, + "base_model_name_or_path": null, + "group_size": 64, + "inference_mode": false, + "init_weights": true, + "peft_type": "ROAD", + "revision": null, + "target_modules": null, + "task_type": null, + "variant": "1" +} diff --git a/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/training_params.json b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/training_params.json new file mode 100644 index 0000000000..52d87e3ef6 --- /dev/null +++ b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/training_params.json @@ -0,0 +1,5 @@ +{ + "optimizer_kwargs": { + "lr": 1e-3 + } +} diff --git a/src/peft/tuners/road/__init__.py b/src/peft/tuners/road/__init__.py index 6eea31b440..7032297c41 100644 --- 
a/src/peft/tuners/road/__init__.py +++ b/src/peft/tuners/road/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/peft/tuners/road/config.py b/src/peft/tuners/road/config.py index 90815b0d44..abaa1e5d2a 100644 --- a/src/peft/tuners/road/config.py +++ b/src/peft/tuners/road/config.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -108,7 +108,7 @@ class RoadConfig(PeftConfig): default=None, metadata={ "help": ( - "List of modules apart from SHiRA layers to be set as trainable and saved in the final checkpoint. For" + "List of modules apart from RoAd layers to be set as trainable and saved in the final checkpoint. For" " example, in Sequence Classification or Token Classification tasks, the final layer" " `classifier/score` are randomly initialized and as such need to be trainable and saved." ) diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py index c2f342d7f9..d05842e266 100644 --- a/src/peft/tuners/road/layer.py +++ b/src/peft/tuners/road/layer.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,18 +31,19 @@ class RoadLayer(BaseTunerLayer): 2D rotation matrix. For additional flexibility, each rotation matrix is multiplied by a trainable scale. when applied to vector R @ x each pair of elements of x is transformed like this: - y₀ = x₀ * α₀cosθ₀ - xₙ * α₀sinθ₀ - yₙ = x₀ * α₀sinθ₀ + xₙ * α₀cosθ₀ + y₀ = x₀ * α * cosθ - xₙ * α * sinθ + yₙ = x₀ * α * sinθ + xₙ * α * cosθ - The scales and angles inside each rotation matrix may actually be different (when using variant 2 or 4). + The scales α and angles θ are learned for each pair of elements and, moreover, each of the 4 instances in the + rotation matrix may actually be different (when using variant 2 or 4). - Note that instead of using two consecutive elements x₀ x₁ we pair elements from the first and second half of the - group, which allows for more efficient inference implementation. + Note that instead of using two consecutive elements x₀ x₁ we first split the whole vector into groups and pair + elements from the first with the second half of the same group, which allows for more efficient inference implementation. The adapter needs to only store the angles θ and scales α, rather than the full matrix R and the inference implementation only needs to do elementwise vector multiplications. - For merging the weights, we make use of the following formula: R @ (W @ x + b) = (R @ W) @ x + R @ b The lhs part + For merging the weights, we make use of the following formula: R @ (W @ x + b) = (R @ W) @ x + R @ b. The lhs part is how it is used in unmerged state (using efficient elementwise implementation instead of matrix multiplication) and the rhs part is how it is used in merged state where (R @ W) becomes the new weight matrix and R @ b becomes the new bias. 
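As a quick check of the merge identity above, here is a standalone sketch (independent of the patch; the shapes are arbitrary, and RoAd's actual R is block-sparse rather than dense):

```python
import torch

# R @ (W @ x + b) == (R @ W) @ x + R @ b holds for any linear map R,
# which is why folding the rotation into the weight and bias is exact.
torch.manual_seed(0)
out_features, in_features = 6, 3
W = torch.randn(out_features, in_features)
b = torch.randn(out_features)
x = torch.randn(in_features)
R = torch.randn(out_features, out_features)

lhs = R @ (W @ x + b)      # unmerged state: rotate the layer output
rhs = (R @ W) @ x + R @ b  # merged state: rotated weight and rotated bias
assert torch.allclose(lhs, rhs, atol=1e-6)
```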
@@ -311,6 +312,10 @@ def __repr__(self) -> str: def _get_delta_weight(variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor): first_col, second_col = _prepare_cols(variant, group_size, road_theta, road_alpha) + # To help understand the logic below consider how rope embeddings work + # here it is similar, but done in groups. + # https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/3 + # First column is simply put on the main diagonal output_tensor = torch.diag(first_col) # For second column we need to swap each half groups and add minus sign diff --git a/src/peft/tuners/road/model.py b/src/peft/tuners/road/model.py index 1c19796665..aaee750727 100644 --- a/src/peft/tuners/road/model.py +++ b/src/peft/tuners/road/model.py @@ -1,4 +1,4 @@ -# Copyright 2023-present the HuggingFace Inc. team. +# Copyright 2025-present the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -212,9 +212,6 @@ def __getattr__(self, name: str): raise return getattr(self.model, name) - # def _check_merge_allowed(self): - # raise ValueError("Road adapters do not support merging") - def _unload_and_optionally_merge( self, merge=True, @@ -296,7 +293,7 @@ def merge_and_unload( def unload(self) -> torch.nn.Module: """ - Gets back the base model by removing all the oft modules without merging. This gives back the original base + Gets back the base model by removing all the road modules without merging. This gives back the original base model. """ return self._unload_and_optionally_merge(merge=False) From 74ac723132cdc4dda7dbda26e93defea18c86021 Mon Sep 17 00:00:00 2001 From: ppetrushkov <39625270+ppetrushkov@users.noreply.github.com> Date: Thu, 31 Jul 2025 02:00:51 +0200 Subject: [PATCH 03/10] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- src/peft/tuners/road/config.py | 8 ++++---- src/peft/tuners/road/layer.py | 1 + src/peft/tuners/road/model.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/peft/tuners/road/config.py b/src/peft/tuners/road/config.py index abaa1e5d2a..b4e1948527 100644 --- a/src/peft/tuners/road/config.py +++ b/src/peft/tuners/road/config.py @@ -31,8 +31,8 @@ class RoadVariant(str, Enum): @dataclass class RoadConfig(PeftConfig): """ - This is the configuration class to store the configuration of a [`RoadModel`]. Road adapter is proposed in - https://arxiv.org/pdf/2409.00119 . + This is the configuration class to store the configuration of a [`RoadModel`]. RoAd adapter is proposed in + https://arxiv.org/pdf/2409.00119. Args: variant (Union[`RoadVariant`, `str`]): @@ -50,7 +50,7 @@ class RoadConfig(PeftConfig): Note that model hidden size (or hidden size per partition when used with tensor parallelism) must be divisible by group_size, so for very small models you might need to reduce this parameter. init_weights (`bool`): - Whether to perform initialization of Road weights. + Whether to perform initialization of RoAd weights. target_modules (`Optional[Union[List[str], str]]`): The names of the modules to apply the adapter to. If this is specified, only the modules with the specified names will be replaced. When passing a string, a regex match will be performed. When passing a list of @@ -65,7 +65,7 @@ class RoadConfig(PeftConfig): variant: Union[str, RoadVariant] = field( default=RoadVariant.ROAD_1, - metadata={"help": ("Variant of the Road model to use. 
")}, + metadata={"help": ("Variant of the Road model to use.")}, ) group_size: int = field( default=64, diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py index d05842e266..c784b0b19d 100644 --- a/src/peft/tuners/road/layer.py +++ b/src/peft/tuners/road/layer.py @@ -51,6 +51,7 @@ class RoadLayer(BaseTunerLayer): """ adapter_layer_names: tuple[str, ...] = ("road_theta", "road_alpha") + other_param_names: tuple[str, ...] = ("variant", "group_size") def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None: self.base_layer = base_layer diff --git a/src/peft/tuners/road/model.py b/src/peft/tuners/road/model.py index aaee750727..bbf1fc4e16 100644 --- a/src/peft/tuners/road/model.py +++ b/src/peft/tuners/road/model.py @@ -162,7 +162,7 @@ def _create_new_module(road_config: RoadConfig, adapter_name, target, **kwargs): # no module could be matched raise ValueError( f"Target module {target} is not supported. Currently, only the following modules are supported: " - "`torch.nn.Linear` " + "`torch.nn.Linear`." ) return new_module From 8067d048e90fd6846276c203b968b4818e9beddb Mon Sep 17 00:00:00 2001 From: Pavel Petrushkov Date: Thu, 31 Jul 2025 01:40:37 +0200 Subject: [PATCH 04/10] Change road enum, extra checks, documentation. --- docs/source/package_reference/road.md | 4 +- examples/road_finetuning/README.md | 12 +++--- examples/road_finetuning/road_finetuning.py | 14 ++++++ .../adapter_config.json | 2 +- .../training_params.json | 0 src/peft/tuners/road/bnb.py | 4 +- src/peft/tuners/road/config.py | 27 +++++++----- src/peft/tuners/road/layer.py | 43 ++++++++++++------- tests/test_common_gpu.py | 2 +- tests/test_custom_models.py | 12 +++--- tests/test_decoder_models.py | 2 +- tests/test_encoder_decoder_models.py | 2 +- tests/test_feature_extraction_models.py | 2 +- tests/test_seq_classifier.py | 2 +- 14 files changed, 78 insertions(+), 50 deletions(-) rename method_comparison/MetaMathQA/experiments/road/{llama-3.2-3B-variant2 => llama-3.2-3B-lr_0.001}/adapter_config.json (91%) rename method_comparison/MetaMathQA/experiments/road/{llama-3.2-3B-variant2 => llama-3.2-3B-lr_0.001}/training_params.json (100%) diff --git a/docs/source/package_reference/road.md b/docs/source/package_reference/road.md index e929028b29..2d2f2abbbb 100644 --- a/docs/source/package_reference/road.md +++ b/docs/source/package_reference/road.md @@ -16,9 +16,9 @@ rendered properly in your Markdown viewer. # RoAd -[RoAd](https://arxiv.org/pdf/2409.00119) is a parameter‑efficient fine‑tuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions, achieving competitive or superior performance with under 0.1% trainable parameters. Unlike LoRA’s batched low‑rank updates, RoAd’s sparse rotations reformulate to simple element‑wise operations, yielding significantly higher serving throughput when handling heterogeneous requests in the same batch. Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, enabling interpretable, composable task‑specific adaptations by combining orthogonal subspaces learned for different tasks. +[RoAd](https://arxiv.org/pdf/2409.00119) is a parameter‑efficient fine‑tuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions. 
RoAd achieves competitive or superior performance compared to other PEFT methods with under 0.1% trainable parameters. Unlike LoRA’s batched low‑rank updates, RoAd’s sparse rotations reformulate to simple element‑wise operations, yielding significantly higher serving throughput when handling heterogeneous requests in the same batch, i.e. serving multiple adapters simultaneously. Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, interpreting its sparse 2D rotations as task-specific interventions within learned subspaces of hidden representations. These orthogonal subspaces can be composed to merge multiple task-specific behaviors—like multilingual capabilities or instruction following—without additional fine-tuning, enabling modular, interpretable adaptations in LLMs.
 
-Finetuning with RoAd typically requires higher learning rate compared to LoRA or similar methods, around 1e-3.
+Finetuning with RoAd typically requires higher learning rate compared to LoRA or similar methods, around 1e-3. Currently RoAd only supports linear layers. RoAd can be used on models quantized with bitsandbytes (4-bit or 8-bit).
 
 ## RoadConfig
diff --git a/examples/road_finetuning/README.md b/examples/road_finetuning/README.md
index e7e8b0d9df..b9ce14017c 100644
--- a/examples/road_finetuning/README.md
+++ b/examples/road_finetuning/README.md
@@ -2,13 +2,11 @@
 
 
 ## Introduction
-[RoAd](https://arxiv.org/pdf/2409.00119) a novel method
-which employs a straightforward 2D rotation to adapt LLMs which is
-remarkably parameter-efficient, delivering good
-performance with < 0.1% trainable parameters; efficient
-in serving requests requiring different adapters within a batch, with an overhead
-comparable to element-wise multiplication instead of batch matrix multiplication;
-enhances LLM’s interpretability.
+
+[RoAd](https://arxiv.org/pdf/2409.00119) is a novel method that adapts LLMs using simple 2D rotations. It is highly parameter-efficient,
+achieving strong performance with less than 0.1% trainable parameters.
+RoAd also supports efficient serving of mixed-adapter requests within a batch, incurring only element-wise computation overhead rather than costly batch matrix multiplications.
+Additionally, it improves model interpretability through structured and composable transformations.
 
 ## Quick start
 ```python
diff --git a/examples/road_finetuning/road_finetuning.py b/examples/road_finetuning/road_finetuning.py
index 30e41a6a53..4e8718f232 100644
--- a/examples/road_finetuning/road_finetuning.py
+++ b/examples/road_finetuning/road_finetuning.py
@@ -1,3 +1,17 @@
+# Copyright 2025-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 import torch
diff --git a/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/adapter_config.json
similarity index 91%
rename from method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json
rename to method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/adapter_config.json
index 3674b02ffe..d0f74c4076 100644
--- a/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/adapter_config.json
+++ b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/adapter_config.json
@@ -8,5 +8,5 @@
     "revision": null,
     "target_modules": null,
     "task_type": null,
-    "variant": "1"
+    "variant": "road_2"
 }
diff --git a/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/training_params.json b/method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/training_params.json
similarity index 100%
rename from method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-variant2/training_params.json
rename to method_comparison/MetaMathQA/experiments/road/llama-3.2-3B-lr_0.001/training_params.json
diff --git a/src/peft/tuners/road/bnb.py b/src/peft/tuners/road/bnb.py
index 0ea9420bf4..95e9b82b0c 100644
--- a/src/peft/tuners/road/bnb.py
+++ b/src/peft/tuners/road/bnb.py
@@ -35,7 +35,7 @@ def __init__(
         self,
         base_layer: torch.nn.Module,
         adapter_name: str,
-        variant: RoadVariant = RoadVariant.ROAD_1,
+        variant: RoadVariant = "road_1",
         group_size: int = 64,
         init_weights: bool = True,
         **kwargs,
@@ -224,7 +224,7 @@ def __init__(
         self,
         base_layer: torch.nn.Module,
         adapter_name: str,
-        variant: RoadVariant = RoadVariant.ROAD_1,
+        variant: RoadVariant = "road_1",
         group_size: int = 64,
         init_weights: bool = True,
         **kwargs,
diff --git a/src/peft/tuners/road/config.py b/src/peft/tuners/road/config.py
index b4e1948527..dbf61b1127 100644
--- a/src/peft/tuners/road/config.py
+++ b/src/peft/tuners/road/config.py
@@ -15,17 +15,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from enum import Enum
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 from peft.config import PeftConfig
 from peft.utils import PeftType
 
 
-class RoadVariant(str, Enum):
-    ROAD_1 = "1"
-    ROAD_2 = "2"
-    ROAD_4 = "4"
+RoadVariant = Literal["road_1", "road_2", "road_4"]
 
 
 @dataclass
@@ -36,10 +32,15 @@ class RoadConfig(PeftConfig):
 
     Args:
         variant (Union[`RoadVariant`, `str`]):
-            The variant of the Road model to use. It can be one of 1, 2, or 4.
-            - 1: Road-1
-            - 2: Road-2
-            - 4: Road-4
+            The variant of the Road model to use. It can be one of road_1, road_2, or road_4. Refer to the paper
+            for more details.
+            - road_1: Uses a single scale and angle for each pair of elements, shared by both elements
+                of the pair. This variant has the lowest number of parameters: it stores a number of
+                parameters equal to the output hidden size for each layer that RoAd is applied to.
+            - road_2: Uses a separate scale and angle for each element.
+                This variant has 2x the number of parameters compared to road_1.
+            - road_4: Uses two separate scales and angles for each element.
+                This variant has 4x the number of parameters compared to road_1.
         group_size (`int`):
            Group size defines how elements are grouped together into 2D vectors for rotation.
Within each group element 0 is paired with element group_size/2, @@ -64,7 +65,7 @@ class RoadConfig(PeftConfig): """ variant: Union[str, RoadVariant] = field( - default=RoadVariant.ROAD_1, + default="road_1", metadata={"help": ("Variant of the Road model to use.")}, ) group_size: int = field( @@ -121,3 +122,7 @@ def __post_init__(self): self.target_modules = ( set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules ) + if self.variant not in ["road_1", "road_2", "road_4"]: + raise ValueError(f"Invalid variant {self.variant} specified. Please choose from road_1, road_2 or road_4") + if self.group_size <= 0 or self.group_size % 2 != 0: + raise ValueError(f"The group_size must be divisible by 2 when using RoadLayer, but got {self.group_size}.") diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py index c784b0b19d..5478118d09 100644 --- a/src/peft/tuners/road/layer.py +++ b/src/peft/tuners/road/layer.py @@ -67,10 +67,7 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, * if isinstance(base_layer, nn.Linear): in_features, out_features = base_layer.in_features, base_layer.out_features else: - in_features, out_features = None, None - warnings.warn( - f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning - ) + raise ValueError(f"Unsupported layer type '{type(base_layer)}' encountered, cannot apply RoAd adapter.") self.in_features = in_features self.out_features = out_features @@ -85,22 +82,25 @@ def update_layer( group_size, init_weights, ): - # collect the kwargs - kwargs = locals().copy() - del kwargs["self"] - self.variant[adapter_name] = variant self.group_size[adapter_name] = group_size + if self.out_features % group_size != 0: + raise ValueError( + f"The out_features of the base layer must be divisible by group_size ({group_size}) when using RoadLayer." + ) + # Actual trainable parameters - if variant == RoadVariant.ROAD_1: + if variant == "road_1": size = self.out_features // 2 - elif variant == RoadVariant.ROAD_2: + elif variant == "road_2": size = self.out_features - elif variant == RoadVariant.ROAD_4: + elif variant == "road_4": size = self.out_features * 2 else: - raise ValueError(f"Unsupported variant {variant} for RoadLayer. Supported variants are 1, 2, and 4.") + raise ValueError( + f"Unsupported variant {variant} for RoadLayer. Supported variants are road_1, road_2, and road_4." + ) self.road_theta[adapter_name] = nn.Parameter(torch.rand(size)) self.road_alpha[adapter_name] = nn.Parameter(torch.rand(size)) @@ -124,7 +124,7 @@ def __init__( self, base_layer, adapter_name: str, - variant: RoadVariant = RoadVariant.ROAD_1, + variant: RoadVariant = "road_1", group_size: int = 64, init_weights: Union[bool, str] = True, **kwargs, @@ -264,6 +264,12 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N if base_layer.bias is not None: orig_bias = base_layer.bias.clone() orig_bias = torch.matmul(road_R.to(orig_dtype), orig_bias) + + if not torch.isfinite(orig_bias).all(): + raise ValueError( + f"NaNs detected in the merged bias. 
The adapter {active_adapter} seems to be broken"
+                    )
+
                 base_layer.bias.data = orig_bias.contiguous().to(orig_dtype)
             else:
                 orig_weight = base_layer.weight.data
@@ -334,7 +340,7 @@ def _prepare_cols(
     variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # In inference mode, this can be cached
-    if variant == RoadVariant.ROAD_1:
+    if variant == "road_1":
         # In each group there are only group_size // 2 parameters that are reused
         road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
         road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
@@ -344,14 +350,14 @@
         first_col = road_alpha * theta_cos
         second_col = road_alpha * theta_sin
 
-    elif variant == RoadVariant.ROAD_2:
+    elif variant == "road_2":
         # Each group has exactly group_size parameters
         theta_cos = road_theta.cos()
         theta_sin = road_theta.sin()
 
         first_col = road_alpha * theta_cos
         second_col = road_alpha * theta_sin
-    elif variant == RoadVariant.ROAD_4:
+    elif variant == "road_4":
         # Each group has 2*group_size parameters, first half used for first column, second half for second column
         road_theta = road_theta.reshape(-1, 2, group_size)
         theta_cos = road_theta[:, 0, :].cos().flatten()
@@ -362,6 +368,10 @@
         first_col = alpha_1 * theta_cos
         second_col = alpha_2 * theta_sin
+    else:
+        raise ValueError(
+            f"Unsupported variant {variant} for RoadLayer. Supported variants are road_1, road_2, and road_4."
+        )
 
     return first_col, second_col
 
@@ -372,6 +382,7 @@ def _apply_road(
     first_col, second_col = _prepare_cols(variant, group_size, road_theta, road_alpha)
 
     # Split in half groups and join back
+    # See equation 4 in the RoAd paper
     x_grouped = x.reshape(-1, 2, group_size // 2)
     x1 = x_grouped[:, 0, :]
     x2 = x_grouped[:, 1, :]
diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py
index 6eba926ebf..6f64c89d31 100644
--- a/tests/test_common_gpu.py
+++ b/tests/test_common_gpu.py
@@ -2168,7 +2168,7 @@ def test_hra_add_new_adapter_does_not_change_device(self, mlp):
 
     def test_road_add_new_adapter_does_not_change_device(self, mlp):
         # same as the HRA test above, but using RoAd
-        config = RoadConfig(target_modules=["lin0"])
+        config = RoadConfig(target_modules=["lin0"], group_size=2)
         model = get_peft_model(mlp, config)
         model = model.to(self.device)
         model.lin0.road_theta.cpu()
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index d93654d8ae..94cbfa6e24 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -691,8 +691,8 @@
     ("Vanilla MLP 2 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "group_size": 2}),
     ("Vanilla MLP 3 RoAd", "MLP", RoadConfig, {"target_modules": ["lin1"], "group_size": 2}),
     ("Vanilla MLP 4 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0", "lin1"], "group_size": 2}),
-    ("Vanilla MLP 5 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "2", "group_size": 2}),
-    ("Vanilla MLP 6 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "4", "group_size": 2}),
+    ("Vanilla MLP 5 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "road_2", "group_size": 2}),
+    ("Vanilla MLP 6 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "road_4", "group_size": 2}),
 ]
 
 # For this test matrix, each tuple consists of:
@@ -909,15 +909,15 @@
         "RoAd 2 Different",
         "road",
         RoadConfig,
-        {"target_modules": ["lin0"], "init_weights": False, "variant": "1", "group_size": 2},
-        {"target_modules": ["lin1"], "init_weights": False, "variant": "2", "group_size": 2},
+        {"target_modules": ["lin0"], "init_weights": False, "variant": "road_1", "group_size": 2},
+        {"target_modules": ["lin1"], "init_weights": False, "variant": "road_2", "group_size": 2},
     ),
     (
         "RoAd 4 Different",
         "road",
         RoadConfig,
-        {"target_modules": ["lin0"], "init_weights": False, "variant": "1", "group_size": 2},
-        {"target_modules": ["lin1"], "init_weights": False, "variant": "4", "group_size": 2},
+        {"target_modules": ["lin0"], "init_weights": False, "variant": "road_1", "group_size": 2},
+        {"target_modules": ["lin1"], "init_weights": False, "variant": "road_4", "group_size": 2},
     ),
 ]
 
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index 605625fcb8..e7063f7d03 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -186,7 +186,7 @@
         RoadConfig,
         {
             "task_type": "CAUSAL_LM",
-            "variant": "1",
+            "variant": "road_1",
             "group_size": 2,
         },
     ),
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index 5f74d7e6a8..1a8a746d6b 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -151,7 +151,7 @@
         RoadConfig,
         {
             "task_type": "SEQ_2_SEQ_LM",
-            "variant": "1",
+            "variant": "road_1",
             "group_size": 2,
         },
     ),
diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index 787ad75b4a..0f3af095f3 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -151,7 +151,7 @@
         RoadConfig,
         {
             "task_type": "FEATURE_EXTRACTION",
-            "variant": "1",
+            "variant": "road_1",
             "group_size": 2,
         },
     ),
diff --git a/tests/test_seq_classifier.py b/tests/test_seq_classifier.py
index 0a02dc4bae..32c38828d7 100644
--- a/tests/test_seq_classifier.py
+++ b/tests/test_seq_classifier.py
@@ -151,7 +151,7 @@
         RoadConfig,
         {
             "task_type": "SEQ_CLS",
-            "variant": "1",
+            "variant": "road_1",
             "group_size": 2,
         },
     ),

From 6c894d9be0858b6ffea012ac21337d52a9cd69f2 Mon Sep 17 00:00:00 2001
From: Pavel Petrushkov
Date: Sun, 3 Aug 2025 22:11:26 +0200
Subject: [PATCH 05/10] Fix initialization and unit tests

---
 docs/source/package_reference/road.md |   4 +-
 src/peft/tuners/road/__init__.py      |   3 +
 src/peft/tuners/road/config.py        |   2 +-
 src/peft/tuners/road/layer.py         |  19 ++-
 src/peft/tuners/road/model.py         |  45 ++++++
 tests/test_custom_models.py           | 192 +++++++++++++++++++-------
 tests/test_initialization.py          |  73 ++++++++++
 7 files changed, 281 insertions(+), 57 deletions(-)

diff --git a/docs/source/package_reference/road.md b/docs/source/package_reference/road.md
index 2d2f2abbbb..52b9514f97 100644
--- a/docs/source/package_reference/road.md
+++ b/docs/source/package_reference/road.md
@@ -18,7 +18,9 @@ rendered properly in your Markdown viewer.
 
 [RoAd](https://arxiv.org/pdf/2409.00119) is a parameter‑efficient fine‑tuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions. RoAd achieves competitive or superior performance compared to other PEFT methods with under 0.1% trainable parameters. Unlike LoRA’s batched low‑rank updates, RoAd’s sparse rotations reformulate to simple element‑wise operations, yielding significantly higher serving throughput when handling heterogeneous requests in the same batch, i.e. serving multiple adapters simultaneously.
Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, interpreting its sparse 2D rotations as task-specific interventions within learned subspaces of hidden representations. These orthogonal subspaces can be composed to merge multiple task-specific behaviors—like multilingual capabilities or instruction following—without additional fine-tuning, enabling modular, interpretable adaptations in LLMs. -Finetuning with RoAd typically requires higher learning rate compared to LoRA or similar methods, around 1e-3. Currently RoAd only supports linear layers. RoAd can be used on models quantized with bitsandbytes (4-bit or 8-bit). +Finetuning with RoAd typically requires higher learning rate compared to LoRA or similar methods, around 1e-3. Currently RoAd only supports linear layers and it can be used on models quantized with bitsandbytes (4-bit or 8-bit). + +For running inference with different RoAd adapters in the same batch see [Inference with different LoRA adapters in the same batch](../developer_guides/lora#inference-with-different-lora-adapters-in-the-same-batch). ## RoadConfig diff --git a/src/peft/tuners/road/__init__.py b/src/peft/tuners/road/__init__.py index 7032297c41..97b2f0f54f 100644 --- a/src/peft/tuners/road/__init__.py +++ b/src/peft/tuners/road/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. + +# Based on implementation made available in https://github.com/ppetrushkov/peft/tree/road (not from paper authors) + from peft.import_utils import is_bnb_4bit_available, is_bnb_available from peft.utils import register_peft_method diff --git a/src/peft/tuners/road/config.py b/src/peft/tuners/road/config.py index dbf61b1127..82c39c6198 100644 --- a/src/peft/tuners/road/config.py +++ b/src/peft/tuners/road/config.py @@ -47,7 +47,7 @@ class RoadConfig(PeftConfig): then element 1 is paired with element group_size/2+1 and so on. This has no effect on the model performance, since elements are unordered, however it has some effect on inference speed when used in e.g. VLLM. - For best speed group size of at least 64 is recommended. + For best speed group size of at least 32 or 64 (the default) is recommended. Note that model hidden size (or hidden size per partition when used with tensor parallelism) must be divisible by group_size, so for very small models you might need to reduce this parameter. init_weights (`bool`): diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py index 5478118d09..e6532bbbb4 100644 --- a/src/peft/tuners/road/layer.py +++ b/src/peft/tuners/road/layer.py @@ -101,18 +101,18 @@ def update_layer( raise ValueError( f"Unsupported variant {variant} for RoadLayer. Supported variants are road_1, road_2, and road_4." 
) - self.road_theta[adapter_name] = nn.Parameter(torch.rand(size)) - self.road_alpha[adapter_name] = nn.Parameter(torch.rand(size)) + self.road_theta[adapter_name] = nn.Parameter(torch.empty(size)) + self.road_alpha[adapter_name] = nn.Parameter(torch.empty(size)) - # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed - if init_weights: - self.reset_road_parameters(adapter_name, init_weights) + self.reset_parameters(adapter_name, init_weights) self._move_adapter_to_device_of_base_layer(adapter_name) self.set_adapter(self.active_adapters) - def reset_road_parameters(self, adapter_name, init_weights): + def reset_parameters(self, adapter_name, init_weights): if init_weights is False: + nn.init.normal_(self.road_theta[adapter_name].data, mean=0.0, std=0.5) + nn.init.normal_(self.road_alpha[adapter_name].data, mean=1.0, std=0.5) return nn.init.zeros_(self.road_theta[adapter_name].data) nn.init.ones_(self.road_alpha[adapter_name].data) @@ -154,7 +154,14 @@ def _check_forward_args(self, x, *args, **kwargs): ) raise ValueError(msg) + if self.merged: + # It is unclear what would be the right thing to do if users pass adapter_names and there are merged + # adapters. Therefore, it is better to raise an error in this case. + msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first." + raise ValueError(msg) + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) adapter_names = kwargs.pop("adapter_names", None) diff --git a/src/peft/tuners/road/model.py b/src/peft/tuners/road/model.py index bbf1fc4e16..c19aa8c39f 100644 --- a/src/peft/tuners/road/model.py +++ b/src/peft/tuners/road/model.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import annotations +from contextlib import contextmanager +from functools import partial import operator from typing import Optional @@ -37,6 +39,11 @@ from .layer import RoadLayer, dispatch_default +def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names): + # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference + kwargs["adapter_names"] = adapter_names + return args, kwargs + class RoadModel(BaseTuner): """ """ @@ -212,6 +219,44 @@ def __getattr__(self, name: str): raise return getattr(self.model, name) + @contextmanager + def _enable_peft_forward_hooks(self, *args, **kwargs): + # If adapter_names is passed as an argument, we inject it into the forward arguments. + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is None: + # nothing to do + yield + return + + if self.training: + raise ValueError("Cannot pass `adapter_names` when the model is in training mode.") + + # Check that users only passed actually existing adapters. + # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want + # to check that there is at least one layer with the given name, or else something like typos can easily slip. 
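+        # For example, adapter_names=["adapter0", "__base__", "adapter1"] routes the first sample through
+        # "adapter0", the second through the unmodified base layer, and the third through "adapter1".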
+ expected_adapters = set() + for layer in self.modules(): + if isinstance(layer, RoadLayer): + expected_adapters |= layer.road_theta.keys() + unique_adapters = {name for name in adapter_names if name != "__base__"} + unexpected_adapters = unique_adapters - expected_adapters + if unexpected_adapters: + raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}") + + hook_handles = [] + for module in self.modules(): + if isinstance(module, RoadLayer): + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + # TODO LoRA also has hooks for beam search, ignore this for now + + yield + + for handle in hook_handles: + handle.remove() + def _unload_and_optionally_merge( self, merge=True, diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 94cbfa6e24..be7aac1dc9 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -4544,18 +4544,27 @@ def test_requires_grad_fourierft_same_targets(self): "base_model.model.lin0.fourierft_spectrum.adapter1", ) +MIXED_ADAPTER_TEST_CASES = [ + ( + "LoRA mixed adapter", + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["lin0"], r=16, init_lora_weights=False), + ), + ( + "RoAd mixed adapter", + RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False), + RoadConfig(target_modules=["lin0"], group_size=2, variant="road_2", init_weights=False), + ), +] class TestMixedAdapterBatches: torch_device = infer_device() - @pytest.fixture - def mlp_lora(self): + def get_mlp_peft(self, config0, config1): """A simple MLP with 2 LoRA adapters""" torch.manual_seed(0) base_model = MLP().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["lin0"], r=16, init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) return peft_model @@ -4594,32 +4603,59 @@ def run_checks(self, model, inputs): assert torch.allclose(output0[1::3], output_mixed[1::3]) assert torch.allclose(output1[2::3], output_mixed[2::3]) - def test_mixed_adapter_batches_lora_mlp(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) + def test_mixed_adapter_batches_lora_mlp(self, test_name, config0, config1): + mlp_peft = self.get_mlp_peft(config0, config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} - self.run_checks(mlp_lora, inputs) + self.run_checks(mlp_peft, inputs) - def test_mixed_adapter_batches_lora_different_target_layers(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with different target layers", + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + "RoAd mixed adapter with different target layers", + RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False), + RoadConfig(target_modules=["lin1"], group_size=2, init_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_different_target_layers(self, test_name, config0, config1): base_model = MLP().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["lin1"], init_lora_weights=False) peft_model = 
get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} self.run_checks(peft_model, inputs) - def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with modules to save", + LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False), + LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False), + ), + ( + "RoAd mixed adapter with modules to save", + RoadConfig(target_modules=["lin0"], modules_to_save=["lin1"], group_size=2, init_weights=False), + RoadConfig(target_modules=["lin0"], modules_to_save=["lin1"], group_size=2, init_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, test_name, config0, config1): base_model = MLP().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} self.run_checks(peft_model, inputs) - def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with unsupported layer", + LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False), + LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, test_name, config0, config1): base_model = MLPWithGRU().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} @@ -4630,50 +4666,80 @@ def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, mlp_lora): ): self.run_checks(peft_model, inputs) - def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with overlapping layers", + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["lin0", "lin1"], init_lora_weights=False), + ), + ( + "RoAd mixed adapter with overlapping layers", + RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False), + RoadConfig(target_modules=["lin0", "lin1"], group_size=2, init_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, test_name, config0, config1): base_model = MLP().to(self.torch_device).eval() # target different lora layers - config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["lin0", "lin1"], init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} self.run_checks(peft_model, 
inputs) - def test_mixed_adapter_batches_lora_conv1d_emb(self): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with conv1d", + LoraConfig(target_modules=["emb", "conv1d"], init_lora_weights=False), + LoraConfig(target_modules=["emb", "conv1d"], r=16, init_lora_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_conv1d_emb(self, test_name, config0, config1): base_model = ModelEmbConv1D().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["emb", "conv1d"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["emb", "conv1d"], r=16, init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} self.run_checks(peft_model, inputs) - def test_mixed_adapter_batches_lora_conv1d_emb_multiple_modules_to_save(self): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with conv1d and emb and modules to save", + LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_conv1d_emb_multiple_modules_to_save(self, test_name, config0, config1): base_model = ModelEmbConv1D().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} self.run_checks(peft_model, inputs) - def test_mixed_adapter_batches_lora_conv2d(self): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with conv2d", + LoraConfig(target_modules=["conv2d"], init_lora_weights=False), + LoraConfig(target_modules=["conv2d"], r=16, init_lora_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_conv2d(self, test_name, config0, config1): base_model = ModelConv2D().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["conv2d"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["conv2d"], r=16, init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) inputs = {"X": torch.arange(270).view(6, 5, 3, 3).to(self.torch_device)} self.run_checks(peft_model, inputs) - def test_mixed_adapter_batches_mha_raises(self): + @pytest.mark.parametrize("test_name, config0, config1", [ + ( + "LoRA mixed adapter with mha", + LoraConfig(target_modules=["mha"], init_lora_weights=False), + LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False), + ), + ]) + def test_mixed_adapter_batches_mha_raises(self, test_name, config0, config1): base_model = ModelMha().to(self.torch_device).eval() - config0 = LoraConfig(target_modules=["mha"], init_lora_weights=False) - config1 = LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) @@ -4682,56 +4748,73 @@ def test_mixed_adapter_batches_mha_raises(self): with pytest.raises(TypeError, match=msg): self.run_checks(peft_model, inputs) - def 
test_mixed_adapter_batches_lora_length_mismatch_raises(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) + def test_mixed_adapter_batches_lora_length_mismatch_raises(self, test_name, config0, config1): + mlp_peft = self.get_mlp_peft(config0, config1) inputs = { "X": torch.arange(90).view(-1, 10).to(self.torch_device), "adapter_names": ["__base__"] * 5, # wrong length! } msg = r"Length of `adapter_names` should be the same as the number of inputs, but got " with pytest.raises(ValueError, match=msg): - mlp_lora.forward(**inputs) + mlp_peft.forward(**inputs) - def test_mixed_adapter_batches_lora_training_mode_raises(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) + def test_mixed_adapter_batches_lora_training_mode_raises(self, test_name, config0, config1): + mlp_peft = self.get_mlp_peft(config0, config1) inputs = { "X": torch.arange(90).view(-1, 10).to(self.torch_device), "adapter_names": ["__base__"] * 9, } - mlp_lora = mlp_lora.train() + mlp_peft = mlp_peft.train() msg = r"Cannot pass `adapter_names` when the model is in training mode." with pytest.raises(ValueError, match=msg): - mlp_lora.forward(**inputs) + mlp_peft.forward(**inputs) - def test_mixed_adapter_batches_lora_disabled(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) + def test_mixed_adapter_batches_lora_disabled(self, test_name, config0, config1): # Disabling adapters should have precedence over passing adapter names + mlp_peft = self.get_mlp_peft(config0, config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} - with mlp_lora.disable_adapter(): - output_disabled = mlp_lora(**inputs) + with mlp_peft.disable_adapter(): + output_disabled = mlp_peft(**inputs) adapters = ["__base__", "adapter0", "adapter1"] inputs["adapter_names"] = [adapters[i % 3] for i in (range(len(inputs["X"])))] - with mlp_lora.disable_adapter(): - output_mixed = mlp_lora.forward(**inputs) + with mlp_peft.disable_adapter(): + output_mixed = mlp_peft.forward(**inputs) assert torch.allclose(output_disabled, output_mixed) - def test_mixed_adapter_batches_lora_merged_raises(self, mlp_lora): + @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) + def test_mixed_adapter_batches_lora_merged_raises(self, test_name, config0, config1): # When there are merged adapters, passing adapter names should raise an error + mlp_peft = self.get_mlp_peft(config0, config1) inputs = { "X": torch.arange(90).view(-1, 10).to(self.torch_device), "adapter_names": ["adapter0"] * 9, } - mlp_lora.merge_adapter(["adapter0"]) + mlp_peft.merge_adapter(["adapter0"]) msg = r"Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first." 
with pytest.raises(ValueError, match=msg): - mlp_lora.forward(**inputs) + mlp_peft.forward(**inputs) - def test_mixed_adapter_batches_lora_wrong_adapter_name_raises(self): + @pytest.mark.parametrize("test_name, config", [ + ( + "LoRA mixed batch wrong adapter name", + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + ), + ( + "LoRA mixed batch wrong adapter name", + RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False), + ), + ]) + def test_mixed_adapter_batches_lora_wrong_adapter_name_raises(self, test_name, config): # Ensure that all of the adapter names that are being passed actually exist torch.manual_seed(0) x = torch.arange(90).view(-1, 10).to(self.torch_device) base_model = MLP().to(self.torch_device).eval() - config = LoraConfig(target_modules=["lin0"], init_lora_weights=False) peft_model = get_peft_model(base_model, config).eval() peft_model.add_adapter(adapter_name="other", peft_config=config) @@ -4786,8 +4869,22 @@ def test_mixed_adapter_batches_lora_with_dora_but_dora_not_included_works(self): } peft_model.forward(**inputs) + @pytest.mark.parametrize("test_name, config0, config1, factor", [ + ( + "LoRA mixed adapter timing", + LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False), + LoraConfig(task_type="CAUSAL_LM", r=16, init_lora_weights=False), + 2.0, + ), + ( + "RoAd mixed adapter timing", + RoadConfig(task_type="CAUSAL_LM", init_weights=False), + RoadConfig(task_type="CAUSAL_LM", variant="road_2", init_weights=False), + 3.0, + ), + ]) @require_non_cpu - def test_mixed_adapter_batches_lora_opt_timing(self): + def test_mixed_adapter_batches_lora_opt_timing(self, test_name, config0, config1, factor): # Use a more realistic model (opt-125m) and do a simple runtime check to ensure that mixed adapter batches # don't add too much overhead. These types of tests are inherently flaky, so we try to add in some robustness. logs = [] # store the time it takes to run each forward pass here @@ -4804,7 +4901,6 @@ def timed(): with timed(): output_base = base_model(**inputs).logits - config0 = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False) peft_model = get_peft_model(base_model, config0, "adapter1").eval() with timed(): output0 = peft_model(**inputs).logits @@ -4812,7 +4908,6 @@ def timed(): # sanity check, outputs are not the same assert not torch.allclose(output_base, output0) - config1 = LoraConfig(task_type="CAUSAL_LM", r=16, init_lora_weights=False) peft_model.add_adapter("adapter2", config1) peft_model.set_adapter("adapter2") with timed(): @@ -4844,7 +4939,6 @@ def timed(): time_non_mixed = (time_base + time0 + time1) / 3 time_mixed = min(time_mixed) - factor = 2.0 assert time_mixed < factor * time_non_mixed # Measure timing of running base and adapter separately vs using a mixed batch. 
Note that on CPU, the diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 2b83ed21f3..1cc75965b1 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -52,6 +52,7 @@ PeftModelForTokenClassification, PrefixTuningConfig, PromptTuningConfig, + RoadConfig, VBLoRAConfig, VeraConfig, get_eva_state_dict, @@ -1756,6 +1757,78 @@ def test_c3a_with_incompatible_block_size_with_out_features(self): with pytest.raises(ValueError, match=msg): get_peft_model(model, config) +class TestRoadInitialization: + torch_device = infer_device() + + def get_model(self): + class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 30, bias=bias) + self.lin1 = nn.Linear(30, 2, bias=bias) + + def forward(self, X): + X = self.lin0(X) + X = self.lin1(X) + return X + + return MLP().to(self.torch_device) + + def get_conv2d_model(self): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + # choose a large weight so that averages are close to expected values + self.linear = nn.Linear(1000, 1000) + self.embed = nn.Embedding(1000, 1000) + self.conv2d = nn.Conv2d(100, 100, 3) + + def forward(self, x): + x_int = (100 * x).int() + x_4d = x.flatten().reshape(1, 100, 10, 10) + return self.linear(x), self.embed(x_int), self.conv2d(x_4d) + + return MyModule().eval().to(self.torch_device) + + def test_road_default_initialization(self): + torch.manual_seed(0) + model = self.get_model() + config = RoadConfig(target_modules=["lin0"], group_size=2) + model = get_peft_model(model, config) + weight_alpha = model.lin0.road_alpha["default"].data + weight_theta = model.lin0.road_theta["default"].data + assert torch.allclose(weight_alpha, torch.ones_like(weight_alpha)) + assert torch.allclose(weight_theta, torch.zeros_like(weight_theta)) + + def test_road_with_odd_group_size(self): + group_size = 3 # odd values are not allowed + msg = f"The group_size must be divisible by 2 when using RoadLayer, but got {group_size}." + with pytest.raises(ValueError, match=re.escape(msg)): + RoadConfig(group_size=group_size) + + def test_road_with_too_large_group_size(self): + group_size = 64 # larger than out_features + msg = f"The out_features of the base layer must be divisible by group_size ({group_size}) when using RoadLayer." + model = self.get_model() + config = RoadConfig(target_modules=["lin0"], group_size=group_size) + with pytest.raises(ValueError, match=re.escape(msg)): + get_peft_model(model, config) + + def test_road_with_incompatible_group_size_with_out_features(self): + group_size = 4 # even, but 30 is not divisible by 4 + model = self.get_model() + config = RoadConfig(target_modules=["lin0"], group_size=group_size) + msg = f"The out_features of the base layer must be divisible by group_size ({group_size}) when using RoadLayer." + with pytest.raises(ValueError, match=re.escape(msg)): + get_peft_model(model, config) + + def test_road_with_conv2d_layer(self): + model = self.get_conv2d_model() + config = RoadConfig(target_modules=["conv2d"], group_size=2) + msg = "Target module Conv2d(100, 100, kernel_size=(3, 3), stride=(1, 1)) is not supported. Currently, only the following modules are supported: `torch.nn.Linear`."
+ with pytest.raises(ValueError, match=re.escape(msg)): + get_peft_model(model, config) + class TestNoInfiniteRecursionDeepspeed: # see #1892 for details From 06a84b90f9d91c8862c0465b618745c3a31120f9 Mon Sep 17 00:00:00 2001 From: ppetrushkov <39625270+ppetrushkov@users.noreply.github.com> Date: Wed, 6 Aug 2025 01:49:48 +0200 Subject: [PATCH 06/10] Apply suggestions from code review Co-authored-by: Benjamin Bossan --- tests/test_custom_models.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index db85d73900..fa596f2d66 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -4729,7 +4729,7 @@ def run_checks(self, model, inputs): assert torch.allclose(output1[2::3], output_mixed[2::3]) @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) - def test_mixed_adapter_batches_lora_mlp(self, test_name, config0, config1): + def test_mixed_adapter_batches_mlp(self, test_name, config0, config1): mlp_peft = self.get_mlp_peft(config0, config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} self.run_checks(mlp_peft, inputs) @@ -4749,7 +4749,7 @@ def test_mixed_adapter_batches_lora_mlp(self, test_name, config0, config1): ), ], ) - def test_mixed_adapter_batches_lora_different_target_layers(self, test_name, config0, config1): + def test_mixed_adapter_batches_different_target_layers(self, test_name, config0, config1): base_model = MLP().to(self.torch_device).eval() peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) @@ -4771,7 +4771,7 @@ def test_mixed_adapter_batches_lora_different_target_layers(self, test_name, con ), ], ) - def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, test_name, config0, config1): + def test_mixed_adapter_batches_multiple_modules_to_save(self, test_name, config0, config1): base_model = MLP().to(self.torch_device).eval() peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) @@ -4788,7 +4788,7 @@ def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, test_name, co ), ], ) - def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, test_name, config0, config1): + def test_mixed_adapter_batches_unsupported_layer_raises(self, test_name, config0, config1): base_model = MLPWithGRU().to(self.torch_device).eval() peft_model = get_peft_model(base_model, config0, "adapter0").eval() peft_model.add_adapter("adapter1", config1) @@ -4815,7 +4815,7 @@ def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, test_name, co ), ], ) - def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, test_name, config0, config1): + def test_mixed_adapter_batches_partly_overlapping_target_layers(self, test_name, config0, config1): base_model = MLP().to(self.torch_device).eval() # target different lora layers peft_model = get_peft_model(base_model, config0, "adapter0").eval() @@ -4898,7 +4898,7 @@ def test_mixed_adapter_batches_mha_raises(self, test_name, config0, config1): self.run_checks(peft_model, inputs) @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) - def test_mixed_adapter_batches_lora_length_mismatch_raises(self, test_name, config0, config1): + def test_mixed_adapter_batches_length_mismatch_raises(self, test_name, config0, config1): mlp_peft = self.get_mlp_peft(config0, config1) inputs = { "X": torch.arange(90).view(-1, 
10).to(self.torch_device), @@ -4909,7 +4909,7 @@ def test_mixed_adapter_batches_lora_length_mismatch_raises(self, test_name, conf mlp_peft.forward(**inputs) @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) - def test_mixed_adapter_batches_lora_training_mode_raises(self, test_name, config0, config1): + def test_mixed_adapter_batches_training_mode_raises(self, test_name, config0, config1): mlp_peft = self.get_mlp_peft(config0, config1) inputs = { "X": torch.arange(90).view(-1, 10).to(self.torch_device), @@ -4921,7 +4921,7 @@ def test_mixed_adapter_batches_lora_training_mode_raises(self, test_name, config mlp_peft.forward(**inputs) @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) - def test_mixed_adapter_batches_lora_disabled(self, test_name, config0, config1): + def test_mixed_adapter_batches_disabled(self, test_name, config0, config1): # Disabling adapters should have precedence over passing adapter names mlp_peft = self.get_mlp_peft(config0, config1) inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)} @@ -4936,7 +4936,7 @@ def test_mixed_adapter_batches_lora_disabled(self, test_name, config0, config1): assert torch.allclose(output_disabled, output_mixed) @pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES) - def test_mixed_adapter_batches_lora_merged_raises(self, test_name, config0, config1): + def test_mixed_adapter_batches_merged_raises(self, test_name, config0, config1): # When there are merged adapters, passing adapter names should raise an error mlp_peft = self.get_mlp_peft(config0, config1) inputs = { @@ -4956,7 +4956,7 @@ def test_mixed_adapter_batches_lora_merged_raises(self, test_name, config0, conf LoraConfig(target_modules=["lin0"], init_lora_weights=False), ), ( - "LoRA mixed batch wrong adapter name", + "RoAD mixed batch wrong adapter name", RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False), ), ], From 070ca710aa309729d9d42be573f41915489170af Mon Sep 17 00:00:00 2001 From: Pavel Petrushkov Date: Wed, 6 Aug 2025 04:02:55 +0200 Subject: [PATCH 07/10] Add gpu tests for bnb road --- examples/road_finetuning/road_finetuning.py | 2 +- tests/test_custom_models.py | 1 + tests/test_gpu_examples.py | 222 ++++++++++++++++++++ 3 files changed, 224 insertions(+), 1 deletion(-) diff --git a/examples/road_finetuning/road_finetuning.py b/examples/road_finetuning/road_finetuning.py index 4e8718f232..830f1da97f 100644 --- a/examples/road_finetuning/road_finetuning.py +++ b/examples/road_finetuning/road_finetuning.py @@ -169,7 +169,7 @@ def tokenize_function(examples): parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval") parser.add_argument("--save_step", type=int, default=100, help="Save step interval") parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training") - parser.add_argument("--variant", type=str, default="1", choices=["1", "2", "4"], help="RoAD variant") + parser.add_argument("--variant", type=str, default="road_1", choices=["road_1", "road_2", "road_4"], help="RoAD variant") parser.add_argument( "--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAd" ) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index fa596f2d66..96c2b6862e 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -4668,6 +4668,7 @@ def test_requires_grad_fourierft_same_targets(self): ) +# this is for PEFT methods that 
support mixed adapter batches. MIXED_ADAPTER_TEST_CASES = [ ( "LoRA mixed adapter", diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index af915d68d6..c29570e0ea 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -64,6 +64,7 @@ PrefixTuningConfig, PromptEncoderConfig, RandLoraConfig, + RoadConfig, TaskType, VeraConfig, get_peft_model, @@ -1717,6 +1718,227 @@ def test_causal_lm_training_multi_gpu_4bit_randlora(self): # assert loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None + @pytest.mark.single_gpu_tests + def test_causal_lm_training_8bit_road(self): + r""" + Same as test_causal_lm_training but with RoAd + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = RoadConfig( + variant="road_1", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=1e-3, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.single_gpu_tests + def test_causal_lm_training_4bit_road(self): + r""" + Same as test_causal_lm_training_4bit but with RoAd + """ + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + device_map="auto", + ) + + tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + model = prepare_model_for_kbit_training(model) + + config = RoadConfig( + variant="road_1", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=1e-3, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + def test_causal_lm_training_multi_gpu_8bit_road(self): + r""" + Same as test_causal_lm_training_multi_gpu but with RoAd + """ + + with 
tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map=DEVICE_MAP_MAP[self.causal_lm_model_id], + quantization_config=BitsAndBytesConfig(load_in_8bit=True), + ) + + assert set(model.hf_device_map.values()) == set(range(device_count)) + assert {p.device.index for p in model.parameters()} == set(range(device_count)) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = RoadConfig( + variant="road_1", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=1e-3, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + def test_causal_lm_training_multi_gpu_4bit_road(self): + r""" + Same as test_causal_lm_training_multi_gpu_4bit but with RoAd + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map=DEVICE_MAP_MAP[self.causal_lm_model_id], + quantization_config=BitsAndBytesConfig(load_in_4bit=True), + ) + + assert set(model.hf_device_map.values()) == set(range(device_count)) + assert {p.device.index for p in model.parameters()} == set(range(device_count)) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = RoadConfig( + variant="road_1", + target_modules=["q_proj", "v_proj"], + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=1e-3, + fp16=True, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.single_gpu_tests def test_causal_lm_training_lora_resize_embeddings_trainable_tokens(self): r""" From 319ab6d26c8e02166b8712341306c2f13e48d6d3 Mon Sep 17 00:00:00 2001 From: Pavel Petrushkov Date: Wed, 6 Aug 2025 11:48:59 +0200 Subject: [PATCH 08/10] Style --- examples/road_finetuning/road_finetuning.py | 4 +++- tests/test_gpu_examples.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git 
a/examples/road_finetuning/road_finetuning.py b/examples/road_finetuning/road_finetuning.py index 830f1da97f..0469785db4 100644 --- a/examples/road_finetuning/road_finetuning.py +++ b/examples/road_finetuning/road_finetuning.py @@ -169,7 +169,9 @@ def tokenize_function(examples): parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval") parser.add_argument("--save_step", type=int, default=100, help="Save step interval") parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training") - parser.add_argument("--variant", type=str, default="road_1", choices=["road_1", "road_2", "road_4"], help="RoAD variant") + parser.add_argument( + "--variant", type=str, default="road_1", choices=["road_1", "road_2", "road_4"], help="RoAD variant" + ) parser.add_argument( "--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAd" ) diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index c29570e0ea..c2b5a37da4 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -1938,7 +1938,6 @@ def test_causal_lm_training_multi_gpu_4bit_road(self): # assert loss is not None assert trainer.state.log_history[-1]["train_loss"] is not None - @pytest.mark.single_gpu_tests def test_causal_lm_training_lora_resize_embeddings_trainable_tokens(self): r""" From 93b71bc94d7e1683050dd8110a47cb4d5902bafa Mon Sep 17 00:00:00 2001 From: Pavel Petrushkov Date: Mon, 18 Aug 2025 18:43:58 +0200 Subject: [PATCH 09/10] Style with doc-builder --- src/peft/tuners/road/config.py | 20 +++++++++----------- src/peft/tuners/road/layer.py | 8 ++++---- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/peft/tuners/road/config.py b/src/peft/tuners/road/config.py index 82c39c6198..50125786c5 100644 --- a/src/peft/tuners/road/config.py +++ b/src/peft/tuners/road/config.py @@ -32,23 +32,21 @@ class RoadConfig(PeftConfig): Args: variant (Union[`RoadVariant`, `str`]): - The variant of the Road model to use. It can be one of road_1, road_2, or road_4. Refer to the paper - for more details. + The variant of the Road model to use. It can be one of road_1, road_2, or road_4. Refer to the paper for + more details. - road_1: Uses the same scale and angle for all pairs of elements. - This variant has lowest number of parameters, it stores a number equal - to the output hidden size of parameters for each layer that RoAd is applied to. + This variant has the lowest number of parameters; it stores a number of parameters equal to the output + hidden size for each layer that RoAd is applied to. - road_2: Uses the same scale and angle for each element. This variant has 2x the number of parameters compared to road_1. - road_4: Uses two different scales and angles for each element. This variant has 4x the number of parameters compared to road_1. group_size (`int`): - Group size defines how elements are grouped together into 2D vectors for rotation. - Within each group element 0 is paired with element group_size/2, - then element 1 is paired with element group_size/2+1 and so on. - This has no effect on the model performance, since elements are unordered, - however it has some effect on inference speed when used in e.g. VLLM. - For best speed group size of at least 32 or 64 (the default) is recommended. - Note that model hidden size (or hidden size per partition when used with tensor parallelism) + Group size defines how elements are grouped together into 2D vectors for rotation. Within each group, + element 0 is paired with element group_size/2, then element 1 is paired with element group_size/2+1, and so + on. This has no effect on the model performance, since elements are unordered; however, it has some effect + on inference speed when used in e.g. vLLM. For best speed, a group size of at least 32 or 64 (the default) is + recommended. Note that the model hidden size (or hidden size per partition when used with tensor parallelism) must be divisible by group_size, so for very small models you might need to reduce this parameter. init_weights (`bool`): Whether to perform initialization of RoAd weights. diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py index 1d3839d9e3..d30aed3078 100644 --- a/src/peft/tuners/road/layer.py +++ b/src/peft/tuners/road/layer.py @@ -30,15 +30,15 @@ class RoadLayer(BaseTunerLayer): Generally the idea of RoAD is to split the input vector into many 2D vectors and rotate each 2D vector with its own 2D rotation matrix. For additional flexibility, each rotation matrix is multiplied by a trainable scale. - when applied to vector R @ x each pair of elements of x is transformed like this: - y₀ = x₀ * α * cosθ - xₙ * α * sinθ - yₙ = x₀ * α * sinθ + xₙ * α * cosθ + When applied to a vector as R @ x, each pair of elements of x is transformed like this: `y₀ = x₀ * α * cosθ - xₙ * α * sinθ` + and `yₙ = x₀ * α * sinθ + xₙ * α * cosθ` The scales α and angles θ are learned for each pair of elements and, moreover, each of the 4 instances in the rotation matrix may actually be different (when using variant 2 or 4). Note that instead of using two consecutive elements x₀ x₁ we first split the whole vector into groups and pair elements from the first with the second half of the same group, which allows for more efficient inference implementation. + elements from the first half with the second half of the same group, which allows for a more efficient inference + implementation. The adapter needs to only store the angles θ and scales α, rather than the full matrix R and the inference implementation only needs to do elementwise vector multiplications. From e4ce021298a5a69adfc436d1058d01b684b009ab Mon Sep 17 00:00:00 2001 From: Pavel Petrushkov Date: Tue, 19 Aug 2025 11:49:25 +0200 Subject: [PATCH 10/10] Add to toc --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 72d2d55199..85439fd042 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -132,6 +132,8 @@ title: C3A - local: package_reference/miss title: MiSS + - local: package_reference/road + title: RoAd title: Adapters - sections:
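To make the rotation arithmetic in the `RoadLayer` docstring of patch 09 concrete, here is a minimal, self-contained sketch of the pairing and rotation for the road_1 case, where one angle and one scale are shared per element pair. The helper name `road_rotate` and the tensor shapes are illustrative assumptions, not the actual PEFT implementation (which lives in `src/peft/tuners/road/layer.py`):

```python
import torch


def road_rotate(x: torch.Tensor, theta: torch.Tensor, alpha: torch.Tensor, group_size: int) -> torch.Tensor:
    """Per pair: y0 = x0 * a * cos(t) - xn * a * sin(t); yn = x0 * a * sin(t) + xn * a * cos(t)."""
    half = group_size // 2
    # Split the hidden dimension into groups, then pair the first half of each
    # group with its second half, i.e. element i with element i + group_size/2.
    xg = x.unflatten(-1, (-1, group_size))
    x0, xn = xg[..., :half], xg[..., half:]
    # road_1: one angle and one scale per pair, so theta and alpha each hold
    # x.shape[-1] // 2 values.
    t = theta.unflatten(-1, (-1, half))
    a = alpha.unflatten(-1, (-1, half))
    y0 = a * torch.cos(t) * x0 - a * torch.sin(t) * xn
    yn = a * torch.sin(t) * x0 + a * torch.cos(t) * xn
    return torch.cat([y0, yn], dim=-1).flatten(-2)


x = torch.randn(2, 8)
theta = torch.zeros(4)  # theta = 0 and alpha = 1 give the identity transform,
alpha = torch.ones(4)   # which is what test_road_default_initialization checks
assert torch.allclose(road_rotate(x, theta, alpha, group_size=4), x)
```

On the inference side, the mixed adapter batch support added in model.py follows the existing LoRA pattern: after `get_peft_model` and `add_adapter`, a batch can be routed per row by passing `adapter_names` (with `"__base__"` selecting the unadapted base model), e.g. `peft_model(X, adapter_names=["__base__", "adapter0", "adapter1"])`, provided the model is in eval mode and no adapter is merged, exactly as the `TestMixedAdapterBatches` cases above exercise.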