diff --git a/src/peft/tuners/loha/config.py b/src/peft/tuners/loha/config.py index 00eacb1301..79c1f63013 100644 --- a/src/peft/tuners/loha/config.py +++ b/src/peft/tuners/loha/config.py @@ -35,7 +35,8 @@ class LoHaConfig(LycorisConfig): module_dropout (`float`): The dropout probability for disabling LoHa modules during training. use_effective_conv2d (`bool`): - Use parameter effective decomposition for Conv2d with ksize > 1 ("Proposition 3" from FedPara paper). + Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 ("Proposition 3" from FedPara + paper). target_modules (`Optional[Union[List[str], str]]`): The names of the modules to apply the adapter to. If this is specified, only the modules with the specified names will be replaced. When passing a string, a regex match will be performed. When passing a list of @@ -79,7 +80,10 @@ class LoHaConfig(LycorisConfig): use_effective_conv2d: bool = field( default=False, metadata={ - "help": 'Use parameter effective decomposition for Conv2d 3x3 with ksize > 1 ("Proposition 3" from FedPara paper)' + "help": ( + "Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 " + '("Proposition 3" from FedPara paper)' + ) }, ) target_modules: Optional[Union[list[str], str]] = field( diff --git a/src/peft/tuners/loha/layer.py b/src/peft/tuners/loha/layer.py index 876e34f826..19582bec80 100644 --- a/src/peft/tuners/loha/layer.py +++ b/src/peft/tuners/loha/layer.py @@ -45,7 +45,7 @@ def _available_adapters(self) -> set[str]: def create_adapter_parameters(self, adapter_name: str, r: int, shape: tuple[int, ...]): # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L130C9-L143C75 - if len(shape) == 4: + if len(shape) == 4: # Conv2d self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3])) self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode @@ -53,7 +53,15 @@ def create_adapter_parameters(self, adapter_name: str, r: int, shape: tuple[int, self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3])) self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode - else: + elif len(shape) == 3: # Conv1d + self.hada_t1[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], 1)) + self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode + self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode + + self.hada_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], 1)) + self.hada_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0])) # out_dim, 1-mode + self.hada_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) # in_dim , 2-mode + else: # Linear self.hada_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0], r)) self.hada_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1])) @@ -127,6 +135,11 @@ def update_layer( if isinstance(base_layer, nn.Linear): shape = tuple(base_layer.weight.shape) elif isinstance(base_layer, nn.Conv2d): + # For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead. 
+ # Since 1x1 convolutions are essentially pointwise operations (matrix multiplications), + # they can be more efficiently handled with the flattened weight representation, + # similar to how Linear layers work. This optimization reduces computational cost + # without affecting the mathematical equivalence of the operation. use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1) if use_effective_conv2d: shape = (base_layer.out_channels, base_layer.in_channels, *base_layer.kernel_size) @@ -135,6 +148,19 @@ def update_layer( base_layer.out_channels, base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1], ) + elif isinstance(base_layer, nn.Conv1d): + # For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons + # as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent + # to a Linear layer applied across the channel dimension. Using flattened representation + # avoids unnecessary reshaping and improves computational efficiency. + use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size[0] != 1 + if use_effective_conv2d: + shape = (base_layer.out_channels, base_layer.in_channels, base_layer.kernel_size[0]) + else: + shape = ( + base_layer.out_channels, + base_layer.in_channels * base_layer.kernel_size[0], + ) else: raise TypeError(f"LoHa is not implemented for base layers of type {type(base_layer).__name__}") @@ -173,6 +199,8 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: ) base_layer = self.get_base_layer() + + # Reshape to match base layer shape weight = weight.reshape(base_layer.weight.shape) # Perform rank dropout during training - drop rows of addition weights @@ -292,6 +320,50 @@ def __repr__(self) -> str: return "loha." + rep +class Conv1d(LoHaLayer): + """LoHa implemented in Conv1d layer""" + + def __init__( + self, + base_layer: nn.Module, + adapter_name: str = "default", + r: int = 0, + alpha: float = 0.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + use_effective_conv2d: bool = False, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer( + adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs + ) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + input = self._cast_input_dtype(input, delta_weight.dtype) + # don't add bias here, because the bias is already included in the output of the base_layer + base_layer = self.get_base_layer() + return F.conv1d( + input, + delta_weight, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + + def __repr__(self) -> str: + rep = super().__repr__() + return "loha." 
+ rep + + # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L9 diff --git a/src/peft/tuners/loha/model.py b/src/peft/tuners/loha/model.py index c13fabc45a..2647972122 100644 --- a/src/peft/tuners/loha/model.py +++ b/src/peft/tuners/loha/model.py @@ -21,7 +21,7 @@ from peft.utils import TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING from peft.utils.other import get_pattern_key -from .layer import Conv2d, Linear, LoHaLayer +from .layer import Conv1d, Conv2d, Linear, LoHaLayer class LoHaModel(LycorisTuner): @@ -85,6 +85,7 @@ class LoHaModel(LycorisTuner): prefix: str = "hada_" layers_mapping: dict[type[torch.nn.Module], type[LoHaLayer]] = { torch.nn.Conv2d: Conv2d, + torch.nn.Conv1d: Conv1d, torch.nn.Linear: Linear, } diff --git a/src/peft/tuners/lokr/config.py b/src/peft/tuners/lokr/config.py index ea8a0e837f..6d25dc5c12 100644 --- a/src/peft/tuners/lokr/config.py +++ b/src/peft/tuners/lokr/config.py @@ -35,7 +35,8 @@ class LoKrConfig(LycorisConfig): module_dropout (`float`): The dropout probability for disabling LoKr modules during training. use_effective_conv2d (`bool`): - Use parameter effective decomposition for Conv2d with ksize > 1 ("Proposition 3" from FedPara paper). + Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 ("Proposition 3" from FedPara + paper). decompose_both (`bool`): Perform rank decomposition of left kronecker product matrix. decompose_factor (`int`): @@ -85,7 +86,10 @@ class LoKrConfig(LycorisConfig): use_effective_conv2d: bool = field( default=False, metadata={ - "help": 'Use parameter effective decomposition for Conv2d 3x3 with ksize > 1 ("Proposition 3" from FedPara paper)' + "help": ( + "Use parameter effective decomposition for Conv2d (and Conv1d) with ksize > 1 " + '("Proposition 3" from FedPara paper)' + ) }, ) decompose_both: bool = field( diff --git a/src/peft/tuners/lokr/layer.py b/src/peft/tuners/lokr/layer.py index 195e5cd4c8..c898065cee 100644 --- a/src/peft/tuners/lokr/layer.py +++ b/src/peft/tuners/lokr/layer.py @@ -75,8 +75,8 @@ def create_adapter_parameters( self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r)) self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0])) - if len(shape) == 4: - # Conv2d + # Handle both Conv2d and Conv1d + if len(shape) == 4: # Conv2d if use_w2: self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:])) elif use_effective_conv2d: @@ -86,6 +86,18 @@ def create_adapter_parameters( else: self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r)) self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3])) + elif len(shape) == 3: # Conv1d + if use_w2: + self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], shape[2])) + elif use_effective_conv2d: # Even for Conv1d, use the effective parameter for kernel dimension + # We pass (r, r, kernel_size, 1) in order to be compatible with the 2d assumptions made + # in make_weight_cp (only relevant for the effective conv2d case). 
+ self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], 1)) + self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1])) # b, 1-mode + self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1])) # d, 2-mode + else: + self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r)) + self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2])) else: # Linear if use_w2: @@ -201,7 +213,27 @@ def update_layer( use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) use_w2 = r >= max(shape[0][1], shape[1][1]) / 2 + # For 1x1 convolutions, disable effective_conv2d to avoid unnecessary tensor reshaping overhead. + # Since 1x1 convolutions are essentially pointwise operations (matrix multiplications), + # they can be more efficiently handled with the flattened weight representation, + # similar to how Linear layers work. This optimization reduces computational cost + # without affecting the mathematical equivalence of the operation. use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1) + elif isinstance(base_layer, nn.Conv1d): + in_dim, out_dim = base_layer.in_channels, base_layer.out_channels + k_size = (base_layer.kernel_size[0],) # Convert to a tuple with single element + + in_m, in_n = factorization(in_dim, decompose_factor) + out_l, out_k = factorization(out_dim, decompose_factor) + shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), k) + + use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2) + use_w2 = r >= max(shape[0][1], shape[1][1]) / 2 + # For Conv1d with kernel_size=1, disable effective_conv2d for the same optimization reasons + # as 1x1 Conv2d. Kernel size 1 means no spatial/temporal context, making it equivalent + # to a Linear layer applied across the channel dimension. Using flattened representation + # avoids unnecessary reshaping and improves computational efficiency. + use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size[0] != 1 else: raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}") @@ -237,7 +269,12 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor: # Make weights with Kronecker product weight = make_kron(w1, w2, self.scaling[adapter_name]) - weight = weight.reshape(self.get_base_layer().weight.shape) + + # Get base layer for reshaping + base_layer = self.get_base_layer() + + # Regular reshape to match base layer shape + weight = weight.reshape(base_layer.weight.shape) # Perform rank dropout during training - drop rows of addition weights rank_dropout = self.rank_dropout[adapter_name] @@ -358,6 +395,52 @@ def __repr__(self) -> str: return "lokr." 
+ rep +class Conv1d(LoKrLayer): + """LoKr implemented in Conv1d layer""" + + def __init__( + self, + base_layer: nn.Module, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + adapter_name: str = "default", + r: int = 0, + alpha: float = 0.0, + rank_dropout: float = 0.0, + module_dropout: float = 0.0, + use_effective_conv2d: bool = False, + init_weights: bool = True, + **kwargs, + ): + super().__init__(base_layer) + + # Create adapter and set it active + self._active_adapter = adapter_name + self.update_layer( + adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs + ) + + def _get_delta_activations( + self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any + ) -> torch.Tensor: + delta_weight = self.get_delta_weight(adapter_name) + input = self._cast_input_dtype(input, delta_weight.dtype) + # don't add bias here, because the bias is already included in the output of the base_layer + base_layer = self.get_base_layer() + return F.conv1d( + input, + delta_weight, + stride=base_layer.stride, + padding=base_layer.padding, + dilation=base_layer.dilation, + groups=base_layer.groups, + ) + + def __repr__(self) -> str: + rep = super().__repr__() + return "lokr." + rep + + # Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11 diff --git a/src/peft/tuners/lokr/model.py b/src/peft/tuners/lokr/model.py index dc0d5bec65..18e0ef64c8 100644 --- a/src/peft/tuners/lokr/model.py +++ b/src/peft/tuners/lokr/model.py @@ -21,7 +21,7 @@ from peft.utils import TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING from peft.utils.other import get_pattern_key -from .layer import Conv2d, Linear, LoKrLayer +from .layer import Conv1d, Conv2d, Linear, LoKrLayer class LoKrModel(LycorisTuner): @@ -86,6 +86,7 @@ class LoKrModel(LycorisTuner): prefix: str = "lokr_" layers_mapping: dict[type[torch.nn.Module], type[LoKrLayer]] = { torch.nn.Conv2d: Conv2d, + torch.nn.Conv1d: Conv1d, torch.nn.Linear: Linear, } diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py index 5e6c308d90..7276b2e92e 100644 --- a/src/peft/tuners/lycoris_utils.py +++ b/src/peft/tuners/lycoris_utils.py @@ -257,7 +257,7 @@ def _create_new_module(cls, config: LycorisConfig, adapter_name: str, target: nn else: target_base_layer = target - if isinstance(target_base_layer, torch.nn.Conv2d): + if isinstance(target_base_layer, (torch.nn.Conv2d, torch.nn.Conv1d)): new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs) elif isinstance(target_base_layer, torch.nn.Linear): new_module = new_module_cls(target, adapter_name=adapter_name, **kwargs) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index a22600138d..393ed2c23c 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -267,6 +267,22 @@ ("Conv2d 2 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"]}), ("Conv2d 3 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}), ("Conv2d 4 LOHA", "Conv2d", LoHaConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}), + ("Conv1d LOHA", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}), + ("Conv1d LOHA 1", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"]}), + ("Conv1d LOHA 2", "Conv1d", LoHaConfig, {"target_modules": ["conv1d"], "r": 2}), + ( + "Conv1d LOHA 3", + "Conv1dBigger", + LoHaConfig, + {"target_modules": 
["conv1d"], "r": 2, "use_effective_conv2d": True}, + ), + ( + "Conv1d LOHA 4", + "Conv1dBigger", + LoHaConfig, + {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False}, + ), + ("Conv2d 1x1 LOHA", "Conv2d1x1", LoHaConfig, {"target_modules": ["conv2d"]}), # LoKr ("Vanilla MLP 1 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0"}), ("Vanilla MLP 2 LOKR", "MLP", LoKrConfig, {"target_modules": ["lin0"]}), @@ -285,10 +301,25 @@ ), ("Vanilla MLP 7 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "rank_dropout": 0.5}), ("Vanilla MLP 8 LOKR", "MLP", LoKrConfig, {"target_modules": "lin0", "decompose_both": True, "r": 1, "alpha": 1}), + ("Conv1d LOKR 1", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"]}), + ("Conv1d LOKR 2", "Conv1d", LoKrConfig, {"target_modules": ["conv1d"], "r": 2}), + ( + "Conv1d LOKR 3", + "Conv1dBigger", + LoKrConfig, + {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": True}, + ), + ( + "Conv1d LOKR 4", + "Conv1dBigger", + LoKrConfig, + {"target_modules": ["conv1d"], "r": 2, "use_effective_conv2d": False}, + ), ("Conv2d 1 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"]}), ("Conv2d 2 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"]}), ("Conv2d 3 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d"], "use_effective_conv2d": True}), ("Conv2d 4 LOKR", "Conv2d", LoKrConfig, {"target_modules": ["conv2d", "lin0"], "use_effective_conv2d": True}), + ("Conv2d 1x1 LOKR", "Conv2d1x1", LoKrConfig, {"target_modules": ["conv2d"]}), ( "Conv2d 5 LOKR", "Conv2d", @@ -1189,6 +1220,28 @@ def forward(self, X): return X +class ModelConv1DBigger(nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(64, 16, 2) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin0 = nn.Linear(144, 2) + self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float + + def forward(self, X): + X = X.to(self.dtype) + X = X.reshape(-1, 1, 10) + X = torch.concat([X] * 64, dim=1) + X = self.conv1d(X) + X = self.relu(X) + X = self.flat(X) + X = self.lin0(X) + X = self.sm(X) + return X + + class ModelConv2D(nn.Module): def __init__(self, bias=True): super().__init__() @@ -1234,6 +1287,27 @@ def forward(self, X): return X +class ModelConv2D1x1(nn.Module): + def __init__(self): + super().__init__() + self.conv2d = nn.Conv2d(1, 10, kernel_size=(1, 1), padding=0) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin0 = nn.Linear(10 * 3 * 3, 2) + self.sm = nn.LogSoftmax(dim=-1) + self.dtype = torch.float + + def forward(self, X): + X = X.to(self.dtype) + X = X.reshape(-1, 1, 3, 3) + X = self.conv2d(X) + X = self.relu(X) + X = self.flat(X) + X = self.lin0(X) + X = self.sm(X) + return X + + class ModelConv2DGroups(nn.Module): def __init__(self): super().__init__() @@ -1282,6 +1356,25 @@ def forward(self, X): return X +class ModelConv1DKernel1(nn.Module): + def __init__(self): + super().__init__() + self.conv1d = nn.Conv1d(in_channels=3, out_channels=10, kernel_size=1) + self.relu = nn.ReLU() + self.flat = nn.Flatten() + self.lin0 = nn.Linear(10 * 10, 2) + self.dtype = torch.float + + def forward(self, x): + x = x.to(self.dtype) + x = x.reshape(-1, 3, 10) # batch, channels, seq_len + x = self.conv1d(x) + x = self.relu(x) + x = self.flat(x) + x = self.lin0(x) + return x + + class ModelConv3D(nn.Module): def __init__(self): super().__init__() @@ -1382,9 +1475,18 @@ def from_pretrained(cls, model_id, torch_dtype=None): if model_id == "Conv1d": return ModelConv1D().to(torch_dtype) + if model_id == "Conv1dBigger": + return 
ModelConv1DBigger().to(torch_dtype) + if model_id == "Conv2d": return ModelConv2D().to(torch_dtype) + if model_id == "Conv2d1x1": + return ModelConv2D1x1().to(torch_dtype) + + if model_id == "Conv1dKernel1": + return ModelConv1DKernel1().to(torch_dtype) + if model_id == "Conv2dGroups": return ModelConv2DGroups().to(torch_dtype)
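
Usage sketch (not part of the patch): a minimal example of how the Conv1d support introduced above would be exercised, assuming the patch is applied. The toy module, layer sizes, and hyperparameters below are illustrative assumptions; only LoHaConfig, get_peft_model, target_modules, and use_effective_conv2d come from the library/patch. The same pattern applies to LoKrConfig.

import torch
from torch import nn

from peft import LoHaConfig, get_peft_model


class TinyConv1dNet(nn.Module):
    # Hypothetical toy model with a Conv1d layer, mirroring the test models above.
    def __init__(self):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3)
        self.lin0 = nn.Linear(16 * 8, 2)

    def forward(self, x):
        # x: (batch, 8, 10) -> conv output: (batch, 16, 8)
        x = torch.relu(self.conv1d(x))
        return self.lin0(x.flatten(start_dim=1))


model = TinyConv1dNet()
# With this patch, Conv1d modules can be targeted directly; use_effective_conv2d
# also covers Conv1d kernels with size > 1 (kernel_size=1 falls back to the
# flattened, Linear-like representation, as noted in the update_layer comments).
config = LoHaConfig(target_modules=["conv1d"], r=4, alpha=4, use_effective_conv2d=True)
peft_model = get_peft_model(model, config)

out = peft_model(torch.randn(2, 8, 10))  # -> shape (2, 2)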