这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/peft/tuners/tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,14 +871,16 @@ def generate_suffixes(s):
# Initialize a set for required suffixes
required_suffixes = set()

for item, suffixes in target_modules_suffix_map.items():
# We sort the target_modules_suffix_map simply to get deterministic behavior, since sets have no order. In theory
# the order should not matter but in case there is a bug, it's better for the bug to be deterministic.
for item, suffixes in sorted(target_modules_suffix_map.items(), key=lambda tup: tup[1]):
# Go through target_modules items, shortest suffixes first
for suffix in suffixes:
# If the suffix is already in required_suffixes or matches other_module_names, skip it
if suffix in required_suffixes or suffix in other_module_suffixes:
continue
# Check if adding this suffix covers the item
if not any(item.endswith(req_suffix) for req_suffix in required_suffixes):
if not any(item.endswith("." + req_suffix) for req_suffix in required_suffixes):
required_suffixes.add(suffix)
break

Expand Down
45 changes: 45 additions & 0 deletions tests/test_tuners_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,3 +1282,48 @@ def test_get_peft_model_applies_find_target_modules(self):
# check that the resulting model is still the same
model_check_after = sum(p.sum() for p in model.parameters())
assert model_check_sum_before == model_check_after

def test_suffix_is_substring_of_other_suffix(self):
    # Regression test based on a real-world bug found in diffusers. The minimal target modules needed to include the
    # suffix 'time_emb_proj'. However, if the suffix 'proj' was already in required_suffixes, 'time_emb_proj' was not
    # added, because the coverage check was `item.endswith(req_suffix)` and 'time_emb_proj' ends with 'proj'. The
    # correct check is `item.endswith("." + req_suffix)`, so that only whole dot-separated name components match.
    # The module names chosen here are only a subset of the hundreds of actual module names, but this subset is
    # sufficient to reproduce the bug.
    target_modules = [
        "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj",
        "mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj",
        "up_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj",
        "mid_block.attentions.0.proj_out",
        "up_blocks.0.attentions.0.proj_out",
        "down_blocks.1.attentions.0.proj_out",
        "up_blocks.0.resnets.0.time_emb_proj",
        "down_blocks.0.resnets.0.time_emb_proj",
        "mid_block.resnets.0.time_emb_proj",
    ]
    other_module_names = [
        "conv_in",
        "time_proj",
        "time_embedding",
        "time_embedding.linear_1",
        "add_time_proj",
        "add_embedding",
        "add_embedding.linear_1",
        "add_embedding.linear_2",
        "down_blocks",
        "down_blocks.0",
        "down_blocks.0.resnets",
        "down_blocks.0.resnets.0",
        "up_blocks",
        "up_blocks.0",
        "up_blocks.0.attentions",
        "up_blocks.0.attentions.0",
        "up_blocks.0.attentions.0.norm",
        "up_blocks.0.attentions.0.transformer_blocks",
        "up_blocks.0.attentions.0.transformer_blocks.0",
        "up_blocks.0.attentions.0.transformer_blocks.0.norm1",
        "up_blocks.0.attentions.0.transformer_blocks.0.attn1",
    ]
    # All three leaf names must be present: 'proj' must not be treated as covering 'proj_out' or 'time_emb_proj'.
    expected = {"time_emb_proj", "proj", "proj_out"}
    result = find_minimal_target_modules(target_modules, other_module_names)
    assert result == expected