diff --git a/src/peft/tuners/oft/config.py b/src/peft/tuners/oft/config.py index 4b33f13cb3..9c62e1bece 100644 --- a/src/peft/tuners/oft/config.py +++ b/src/peft/tuners/oft/config.py @@ -14,9 +14,12 @@ from __future__ import annotations +import warnings from dataclasses import dataclass, field from typing import Literal, Optional, Union +import packaging.version + from peft.config import PeftConfig from peft.utils import PeftType @@ -193,4 +196,18 @@ def check_kwargs(cls, **kwargs): "with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions. " "Alternatively, downgrade PEFT to version 0.13.0 to use the old adapter weights." ) + if kwargs.get("use_cayley_neumann", False): + peft_version = kwargs.get("peft_version", "0.0.0") # if not present, set a low dummy version + # remove commit hash, if present + peft_version = peft_version.partition("@")[0] + parsed_version = packaging.version.Version(peft_version) + min_version = packaging.version.Version("0.18.0") + # note: config.peft_version was added in 0.18.0, so if it's missing, it means we're below min version + if parsed_version < min_version: + msg = ( + "The cayley-neumann parameterization has been slightly changed to be more numerically stable in " + "PEFT 0.18.0. Please retrain your adapter weights with newer PEFT versions. Alternatively, " + "downgrade PEFT to version 0.17.0 to use the old parameterization." + ) + warnings.warn(msg) return super().check_kwargs(**kwargs) diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py index 6b14d015ae..e18f17f38e 100644 --- a/src/peft/tuners/oft/layer.py +++ b/src/peft/tuners/oft/layer.py @@ -97,7 +97,9 @@ def __init__( self.use_cayley_neumann = use_cayley_neumann self.num_cayley_neumann_terms = num_cayley_neumann_terms # Create indices for upper triangle (excluding diagonal) - self.rows, self.cols = torch.triu_indices(block_size, block_size, 1) + rows, cols = torch.triu_indices(block_size, block_size, 1) + self.register_buffer("rows", rows, persistent=False) + self.register_buffer("cols", cols, persistent=False) def _pytorch_skew_symmetric(self, vec, block_size): batch_size = vec.shape[0] @@ -139,9 +141,11 @@ def _cayley_batch( R.add_(Q_squared, alpha=2.0) Q_power = Q_squared - for i in range(3, num_neumann_terms): + for _ in range(3, num_neumann_terms - 1): Q_power = torch.bmm(Q_power, Q_skew) R.add_(Q_power, alpha=2.0) + Q_power = torch.bmm(Q_power, Q_skew) + R.add_(Q_power) else: id_mat = ( torch.eye(Q_skew.shape[-1], device=Q_skew.device) @@ -621,9 +625,13 @@ def unmerge(self) -> None: if active_adapter in self.oft_R.keys(): oft_mat = self.get_delta_weight(active_adapter) + previous_dtype = oft_mat.dtype + if previous_dtype != torch.float32: + oft_mat = oft_mat.to(torch.float32) + orig_weights = self.get_base_layer().weight.data orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype)) + orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) base_layer.weight.data = orig_weights.to(orig_dtype) @@ -855,13 +863,17 @@ def unmerge(self) -> None: if active_adapter in self.oft_R.keys(): oft_mat = self.get_delta_weight(active_adapter) + previous_dtype = oft_mat.dtype + if previous_dtype != torch.float32: + oft_mat = oft_mat.to(torch.float32) + orig_weights = self.get_base_layer().weight.data.clone() orig_weights = orig_weights.view( self.out_features, self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0], ) orig_weights = torch.transpose(orig_weights, 0, 1) - orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype)) + orig_weights = torch.mm(torch.linalg.inv(oft_mat).to(previous_dtype), orig_weights.to(previous_dtype)) orig_weights = torch.transpose(orig_weights, 0, 1) orig_weights = orig_weights.view( self.out_features,