Merged
9 changes: 7 additions & 2 deletions src/peft/peft_model.py
@@ -40,9 +40,10 @@

from peft.tuners.lora.variants import get_alora_offsets_for_forward, get_alora_offsets_for_generate
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils import AuxiliaryTrainingWrapper
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.utils.integrations import init_empty_weights
-from peft.utils.other import create_attention_mask, set_additional_trainable_modules
+from peft.utils.other import TrainableTokensWrapper, create_attention_mask, set_additional_trainable_modules

from . import __version__
from .config import PeftConfig
@@ -3047,7 +3048,11 @@ def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:

layer_status: list[TunerLayerStatus] = []
for name, module in base_model.named_modules():
-if not isinstance(module, BaseTunerLayer):
+if not isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
continue
if isinstance(module, TrainableTokensWrapper):
# Skip TrainableTokensWrapper, since it wraps TrainableTokensLayer, which is the actual PEFT layer we're
# interested in.
continue

# determine if all submodules/parameters of this module require grad or not
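
For orientation, the user-visible effect of this change: `get_layer_status` previously skipped anything that was not a `BaseTunerLayer`, so `modules_to_save` and trainable-token wrappers never appeared in the report. A minimal usage sketch modelled on the new tests below (the toy model and the default adapter name are illustrative, not part of this PR):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)  # targeted by LoRA
        self.lin1 = nn.Linear(10, 10)  # fully trained via modules_to_save

config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(SmallModel(), config)

# With this change, the ModulesToSaveWrapper around lin1 shows up as a second entry.
for status in model.get_layer_status():
    print(status.name, status.module_type, status.available_adapters)
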
26 changes: 26 additions & 0 deletions src/peft/utils/other.py
@@ -238,6 +238,13 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):

"""

# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ()
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str, ...] = ()
# List all merged adapters
merged_adapters: list[str] = []

def __init__(self, module_to_save, adapter_name, **kwargs):
"""Extra kwargs will be passed to `self.init_modules` and `self.update`."""
super().__init__()
@@ -255,6 +262,10 @@ def init_modules(self, adapter_name, **kwargs):
"""A place to initialize PyTorch modules in `__init__` before the call to `self.update()`."""
raise NotImplementedError

def _get_available_adapters(self) -> set[str]:
"""Return all adapter names that can be found on this module."""
raise NotImplementedError

def _error_message_name(self):
"""Returns a user friendly identifier for error messages, e.g. for type compatibility error messages from
`check_module()` so that the user can backtrack where the error comes from. A generic "training wrapper" is
@@ -492,6 +503,9 @@ def unload_and_optionally_merge_module(
class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
"""Wraps a module that is supposed to be trained (i.e. `requires_grad_(True)`) and saved after training."""

# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("modules_to_save",)

def __init__(self, module_to_save, adapter_name):
super().__init__(module_to_save, adapter_name)

@@ -700,6 +714,10 @@ def unload_and_optionally_merge_module(

return new_module

def _get_available_adapters(self) -> set[str]:
"""Return all adapter names that can be found on this module."""
return set(self.modules_to_save.keys())


class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
"""Wraps a module (typically an embedding layer) that is supposed to be re-trained selectively (i.e.
@@ -709,6 +727,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
`TrainableTokensLayer`.
"""

# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("token_adapter.trainable_tokens_delta",)
other_param_names: tuple[str, ...] = ("token_adapter.token_indices", "token_adapter.trainable_tokens_original")

def __init__(
self,
module_to_save: torch.nn.Module,
@@ -871,6 +893,10 @@ def unload_and_optionally_merge_module(
self.token_adapter.merge(safe_merge=safe_merge, adapter_names=adapter_names)
return self.token_adapter.get_base_layer()

def _get_available_adapters(self) -> set[str]:
"""Return all adapter names that can be found on this module."""
return set(self.token_adapter.trainable_tokens_delta.keys())


def _get_input_embeddings_name(model, default=None):
if not hasattr(model, "get_input_embeddings"):
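
The class attributes added above mirror the contract of `BaseTunerLayer`, so code that inspects adapter state no longer needs to special-case auxiliary wrappers. A simplified illustration of a consumer of that contract (not the actual `get_layer_status` implementation; the helper name is made up):

from operator import attrgetter

def adapter_containers(wrapper):
    """Map each adapter-bearing attribute of an AuxiliaryTrainingWrapper to its container.

    Illustrative only: it relies solely on the class attributes introduced in this PR
    and uses attrgetter because TrainableTokensWrapper lists dotted names such as
    "token_adapter.trainable_tokens_delta".
    """
    names = tuple(wrapper.adapter_layer_names) + tuple(wrapper.other_param_names)
    return {attr: attrgetter(attr)(wrapper) for attr in names}

Together with `_get_available_adapters()`, this gives status-reporting code one uniform way to discover adapters, whether they live in `modules_to_save` or in `token_adapter.trainable_tokens_delta` (both are dict-like containers keyed by adapter name).
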
180 changes: 178 additions & 2 deletions tests/test_tuners_utils.py
@@ -606,15 +606,29 @@ class TestModelAndLayerStatus:
torch_device = infer_device()

@pytest.fixture
-def small_model(self):
+def small_base_model_cls(self):
class SmallModel(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(10, 10)
self.lin1 = nn.Linear(10, 10)

return SmallModel

@pytest.fixture
def small_base_emb_model_cls(self):
class SmallEmbModel(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(10, 10)
self.emb = nn.Embedding(10, 10)

return SmallEmbModel

@pytest.fixture
def small_model(self, small_base_model_cls):
config = LoraConfig(target_modules="lin0")
-return get_peft_model(SmallModel(), config)
+return get_peft_model(small_base_model_cls(), config)

@pytest.fixture
def large_model(self):
@@ -801,6 +815,44 @@ def test_devices_all_cpu_large(self, large_model):
]
assert result == expected

def test_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in layer status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
layer_status = model.get_layer_status()

assert len(layer_status) == 2
status = layer_status[1] # for modules_to_save

assert status.name == "model.lin1"
assert status.module_type == "ModulesToSaveWrapper"
assert status.enabled is True
assert status.active_adapters == ["default"]
assert status.merged_adapters == []
assert status.available_adapters == ["default"]
assert status.requires_grad == {"default": True}
assert status.devices == {"default": ["cpu"]}

def test_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in layer status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
layer_status = model.get_layer_status()

assert len(layer_status) == 2
status = layer_status[1] # for trainable tokens

assert status.name == "model.emb.token_adapter"
assert status.module_type == "TrainableTokensLayer"
assert status.enabled is True
assert status.active_adapters == ["default"]
assert status.merged_adapters == []
assert status.available_adapters == ["default"]
assert status.requires_grad == {"default": True}
assert status.devices == {"default": ["cpu"]}

@require_non_cpu
def test_devices_all_gpu_large(self, large_model):
large_model.to(self.torch_device)
@@ -932,6 +984,32 @@ def test_model_enabled_irregular(self, large_model):
model_status = large_model.get_model_status()
assert model_status.enabled == "irregular"

def test_model_enabled_irregular_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)

# disable only lin0
model.lin0.enable_adapters(False)

model_status = model.get_model_status()
# since lin1 is still enabled, the overall model status is "irregular"
assert model_status.enabled == "irregular"

def test_model_enabled_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)

# disable only lin0
model.lin0.enable_adapters(False)

model_status = model.get_model_status()
# since emb is still enabled, the overall model status is "irregular"
assert model_status.enabled == "irregular"

def test_model_active_adapters_small(self, small_model):
model_status = small_model.get_model_status()
assert model_status.active_adapters == ["default"]
@@ -958,6 +1036,34 @@ def test_model_active_adapters_irregular(self, large_model):
model_status = large_model.get_model_status()
assert model_status.active_adapters == "irregular"

def test_model_active_adapters_with_modules_to_save_irregular(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
model.add_adapter("other", config)

# switch modules_to_save to "other"
model.lin1.set_adapter("other")

model_status = model.get_model_status()
# since lin0 is still on "default", the overall model status is "irregular"
assert model_status.active_adapters == "irregular"

def test_model_active_adapters_with_trainable_tokens_irregular(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
model.add_adapter("other", config)

# switch trainable tokens to "other"
model.emb.set_adapter("other")

model_status = model.get_model_status()
# since lin0 is still on "default", the overall model status is "irregular"
assert model_status.active_adapters == "irregular"

def test_model_merged_adapters_small(self, small_model):
model_status = small_model.get_model_status()
assert model_status.merged_adapters == []
@@ -1021,6 +1127,32 @@ def test_model_requires_grad_model_irregular(self, large_model):
model_status = large_model.get_model_status()
assert model_status.requires_grad == {"default": "irregular", "other": False}

def test_model_requires_irregular_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)

# set modules_to_save to requires_grad=False
model.lin1.modules_to_save.default.weight.requires_grad = False

model_status = model.get_model_status()
# since lin1 is still requires_grad=True, the overall model status is "irregular"
assert model_status.requires_grad == {"default": "irregular"}

def test_model_requires_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)

# set trainable tokens to requires_grad=False
model.emb.token_adapter.trainable_tokens_delta.default.requires_grad = False

model_status = model.get_model_status()
# since emb is still requires_grad=True, the overall model status is "irregular"
assert model_status.requires_grad == {"default": "irregular"}

def test_model_available_adapters_small(self, small_model):
model_status = small_model.get_model_status()
assert model_status.available_adapters == ["default"]
@@ -1075,6 +1207,50 @@ def test_model_target_parameters_and_target_modules(self, large_model):
assert model_status.num_adapter_layers == 2
assert model_status.trainable_params == 2 * (8 * 10 + 10 * 8)

def test_model_status_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
num_base_params = sum(p.numel() for p in small_base_model_cls().parameters())
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
model_status = model.get_model_status()

assert model_status.base_model_type == "SmallModel"
assert model_status.adapter_model_type == "LoraModel"
assert model_status.peft_types == {"default": "LORA"}
# 2 x 80 for LoRA, 100 for modules_to_save.weight, 10 for modules_to_save.bias
assert model_status.trainable_params == 2 * 80 + 100 + 10
assert model_status.total_params == 2 * 80 + 100 + 10 + num_base_params
assert model_status.num_adapter_layers == 2 # lin0 + lin1
assert model_status.enabled is True
assert model_status.active_adapters == ["default"]
assert model_status.merged_adapters == []
assert model_status.requires_grad == {"default": True}
assert model_status.available_adapters == ["default"]
assert model_status.devices == {"default": ["cpu"]} # all on CPU

def test_model_status_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
num_base_params = sum(p.numel() for p in small_base_emb_model_cls().parameters())
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
model_status = model.get_model_status()

assert model_status.base_model_type == "SmallEmbModel"
assert model_status.adapter_model_type == "LoraModel"
assert model_status.peft_types == {"default": "LORA"}
# 2 x 80 for LoRA, 3 x 10 for trainable tokens
assert model_status.trainable_params == 2 * 80 + 3 * 10
assert model_status.total_params == 2 * 80 + 3 * 10 + num_base_params
assert model_status.num_adapter_layers == 2
assert model_status.enabled is True
assert model_status.active_adapters == ["default"]
assert model_status.merged_adapters == []
assert model_status.requires_grad == {"default": True}
assert model_status.available_adapters == ["default"]
assert model_status.devices == {"default": ["cpu"]} # all on CPU

def test_loha_model(self):
# ensure that this also works with non-LoRA, it's not necessary to test all tuners
class SmallModel(nn.Module):
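
The model-level tests above rely on `get_model_status` collapsing per-layer values into one report and falling back to the literal string "irregular" when layers disagree. Reproducing the disabled/enabled mismatch outside the test suite could look like this (toy model as in the fixtures above; purely illustrative):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)

config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(SmallModel(), config)
model.lin0.enable_adapters(False)  # disable only the LoRA layer

# lin1's ModulesToSaveWrapper is still enabled, so the aggregate is "irregular"
print(model.get_model_status().enabled)
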
5 changes: 3 additions & 2 deletions tests/testing_common.py
@@ -21,6 +21,7 @@
import tempfile
import warnings
from dataclasses import replace
from operator import attrgetter

import pytest
import torch
@@ -1453,7 +1454,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
target, "other_param_names", []
)
for attr in attributes_to_check:
-assert adapter_to_delete not in getattr(target, attr)
+assert adapter_to_delete not in attrgetter(attr)(target)

# check auxiliary modules
for module in model.modules():
@@ -1527,7 +1528,7 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
target, "other_param_names", []
)
for attr in attributes_to_check:
-assert adapter_to_delete not in getattr(target, attr)
+assert adapter_to_delete not in attrgetter(attr)(target)

# check auxiliary modules
for module in model.modules():
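
The switch from `getattr` to `operator.attrgetter` is needed because `other_param_names` on `TrainableTokensWrapper` now contains dotted paths, which plain `getattr` cannot resolve. A standalone illustration with toy classes (not PEFT code):

from operator import attrgetter

class TokenAdapter:
    def __init__(self):
        self.trainable_tokens_delta = {"default": None}

class Wrapper:
    def __init__(self):
        self.token_adapter = TokenAdapter()

w = Wrapper()
print(attrgetter("token_adapter.trainable_tokens_delta")(w))    # follows the dotted path -> {'default': None}
print(getattr(w, "token_adapter.trainable_tokens_delta", None))  # literal attribute lookup -> None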