System Info
In peft, when creating a `LoraConfig` with `layers_to_transform=[0, ...]`, modules whose keys have fewer than two name segments before the layer index (e.g. `my_module.0.self_attention.query`) are not adapted, even when they match `target_modules`.
Possible solution: in `check_target_module_exists` in `peft/tuners/tuners_utils.py`, change to:
```python
if is_using_layer_indexes and target_module_found:
    layer_index = None
    # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave
    # For now, empty layers_pattern means any layer pattern is ok
    if layers_pattern is None or len(layers_pattern) == 0:
        # layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key)  # previous line
        layer_index = re.match(r".*\.(\d+)\.", key)
    else:
        layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern
        for pattern in layers_pattern:
            # layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key)  # previous line
            layer_index = re.match(rf"{pattern}\.(\d+)\.", key)
            if layer_index is not None:
                break
```
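A quick standalone sanity check of the proposed patterns, assuming the dot-separated key format peft uses for module names (the key strings below are illustrative, taken from this issue):

```python
import re

shallow = "my_module.0.self_attention.query"
nested = "my_module.my_nested_module.0.self_attention.query"

# Proposed default pattern: only one name segment is required before the index,
# so both the shallow and the nested key yield a layer index
proposed_default = r".*\.(\d+)\."
assert re.match(proposed_default, shallow).group(1) == "0"
assert re.match(proposed_default, nested).group(1) == "0"

# Proposed layers_pattern branch: note that re.match anchors at the start of
# the string, so with this change layers_pattern must be a prefix of the key
pattern = "my_module"
assert re.match(rf"{pattern}\.(\d+)\.", shallow).group(1) == "0"
```

One consequence of dropping the `.*\.` prefix in the `layers_pattern` branch is that the pattern now has to match from the beginning of the key rather than anywhere inside it, which may be worth calling out when fixing this.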
Who can help?
Reproduction
Create a `LoraConfig` with `layers_to_transform=<non-empty-list>` on an `nn.Module` where the `nn.ModuleList` is directly nested (not at least two levels deep).
I.e. this would not work: `my_module.0.self_attention.query`,
while this would work: `my_module.my_nested_module.0.self_attention.query`.
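The failure can be reproduced without building a model, by running the current default pattern from `check_target_module_exists` against the two key shapes above:

```python
import re

# Default pattern currently used when layers_pattern is empty: it requires
# two name segments before the layer index
current = r".*\.[^.]*\.(\d+)\."

# Key with a single nesting level before the ModuleList index: no layer
# index is found, so the module is skipped
print(re.match(current, "my_module.0.self_attention.query"))  # None

# Key with two nesting levels: the layer index is found as expected
print(re.match(current, "my_module.my_nested_module.0.self_attention.query").group(1))  # "0"
```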
Expected behavior
Even when the model has only one level of nesting, the layers should be adapted according to `layers_to_transform`.