Implement ensure_weight_tying for trainable_token_indices (#2864) #2870
Conversation
BenjaminBossan left a comment
Thanks a lot for handling the update of weight tying for trainable tokens. What's there already looks quite good, but I wonder if we can simplify the implementation; please check my suggestions.
Regarding the tests, I wanted to map the tests you wrote onto the table from #2864, this is what I ended up with:
| weights tied | ensure_weight_tying | LoraConfig trainable_token_indices | result | test |
|---|---|---|---|---|
| False | False | [1, 2, 3] | trainable tokens on embeddings only | |
| False | True | [1, 2, 3] | warn & trainable tokens on embeddings only | test_ensure_weight_tying_warns_when_model_not_tied_list_format |
| True | False | [1, 2, 3] | tied trainable tokens | |
| True | True | [1, 2, 3] | tied trainable tokens | test_ensure_weight_tying_with_single_layer |
| False | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_weight_tying_bc_same_indices_applied |
| True | True | {"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens | test_ensure_weight_tying_applied_with_same_indices |
| False | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | treat as separate | |
| False | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | warn & treat as separate | |
| True | False | {"lm_head": [1,2], "embed_tokens": [3,4]} | *treat as separate | test_weight_tying_bc_different_indices_treated_separately |
| True | True | {"lm_head": [1,2], "embed_tokens": [3,4]} | *error | test_ensure_weight_tying_errors_with_different_indices |
Does this look right to you? I think it means there are still a few gaps in the tests, could you please provide the missing ones? Some tests could be combined via pytest.mark.parametrize if the expected outcomes are the same.
tests/test_trainable_tokens.py (outdated)

```python
]
assert warnings_found

def test_ensure_weight_tying_warns_when_model_not_tied_dict_format(self, model_weight_untied, recwarn):
```
This test can be merged with test_ensure_weight_tying_warns_when_model_not_tied_list_format by parametrizing the trainable_token_indices argument.
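For illustration, a rough sketch of what the merged, parametrized test could look like (the fixture name and the expected warning text come from the surrounding diff; the LoRA target module is an assumption):

```python
import pytest
from peft import LoraConfig, get_peft_model


@pytest.mark.parametrize(
    "trainable_token_indices",
    [[1, 2, 3], {"embed_tokens": [1, 2, 3]}],  # cover both the list and the dict format
)
def test_ensure_weight_tying_warns_when_model_not_tied(model_weight_untied, trainable_token_indices, recwarn):
    # model_weight_untied is assumed to be a fixture providing a model without tied embeddings
    config = LoraConfig(
        target_modules=["q_proj"],  # assumed target module name
        trainable_token_indices=trainable_token_indices,
        ensure_weight_tying=True,
    )
    get_peft_model(model_weight_untied, config)
    expected = "ensure_weight_tying=True but the model does not have tied weights"
    assert any(expected in str(w.message) for w in recwarn)
```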
resolved in 232c6e7
tests/test_trainable_tokens.py (outdated)

```python
warnings_list = [w.message.args[0] for w in recwarn]
warnings_found = [
    msg for msg in warnings_list if "ensure_weight_tying=True but the model does not have tied weights" in msg
]
assert warnings_found
```
I think it's a bit more elegant to do:

```python
expected = ...
assert any(expected in msg for msg in warnings_list)
```
resolved in 232c6e7
tests/test_trainable_tokens.py (outdated)

```python
    ensure_weight_tying=True,
)

with pytest.raises(ValueError) as e:
```
Let's use:

```python
msg = "Cannot ensure weight tying when different token indices are specified"
with pytest.raises(ValueError, match=msg):
```
resolved in 232c6e7
src/peft/utils/other.py (outdated)

```python
ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

# Check if we're dealing with dict format that specifies both embed_tokens and lm_head
is_dict_format = isinstance(peft_config.trainable_token_indices, dict)
```
I don't think we need is_dict_format. The check below, len(target_layers) > 1, is already enough, is it not?
Yes, I re-reviewed this and simplified the logic significantly. Reference 232c6e7 for the implementation.
src/peft/utils/other.py (outdated)

```python
if "embed" in key_lower and not ("lm" in key_lower or "head" in key_lower):
    embed_key = key
elif "lm_head" in key_lower or ("head" in key_lower and "lm" not in key_lower):
    lm_head_key = key
```
I wonder if we overcomplicate things here. If there are multiple target_layers, can we not just compare them to the tied weights? Is it important to identify here which one is for the embedding and which one is for the LM head?
Below, you're using the names for the error message, which is a nice touch, but if we can refrain from guessing here, it would be worth it to make the error message more generic IMO.
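For example, something along these lines could avoid the name guessing entirely (a sketch only; model, target_layers, and the endswith handling are taken from the surrounding diff context):

```python
# Derive the tied module names from the model itself instead of matching on "embed"/"lm_head"
tied_module_names = {key.rpartition(".")[0] for key in model._tied_weights_keys}
# Target layers that participate in weight tying, regardless of what they are called
tied_layer_keys = [
    name for name in target_layers
    if any(tied_name.endswith(name) for tied_name in tied_module_names)
]
```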
I re-looked at this and removed the string-matching logic (checking for "embed", "lm_head", etc.); the code now directly compares the target layers against model._tied_weights_keys and the actual embedding layer. The error message is now generic, showing all conflicting tied layers instead of assuming specific names.
src/peft/utils/other.py (outdated)

```python
    indices_mismatch = True
else:
    # Same indices - if weights are tied and we're applying tying, skip lm_head (it'll be tied later)
    if weights_tied and not (not ensure_weight_tying and False):  # Will apply tying
```
This check makes no sense to me: why the `and False`?
resolved in 232c6e7
About the test coverage: the table looks correct. I've filled all 6 gaps in the test coverage.

@BenjaminBossan Thank you for the detailed review. I have made all the changes and would appreciate it if you could have a look again. I'll make any changes necessary.
BenjaminBossan left a comment
Thanks for iterating on the PR and extending the tests. I still have a few comments, please check.
As a general remark, the logic for handling weight tying in trainable tokens is inherently quite complex. Therefore, I focused on checking if the implementation is clear and simple while keeping the functionality intact. When I found code that I thought could be improved in this regard, I added a comment. But I would also kindly ask you to double check if you can find anything that can be simplified and apply it, even if I haven't commented on it. This will help with the long term health of the PEFT code base 🙏
src/peft/utils/other.py (outdated)

```python
weights_tied = (
    model_config.get("tie_word_embeddings", False)
    # some models may be misconfigured to have weight tying enabled but don't define tied weights keys
    and model._tied_weights_keys is not None
```
This could theoretically raise an AttributeError if used with a non-HF transformers model, right? It's not so likely in practice, since a non-HF transformers model is unlikely to have a model config with tie_word_embeddings, but let's still use getattr here to be safe. I would also assign this to a variable, as it's used 3 times in total.
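A minimal sketch of the suggested change (the variable name is illustrative):

```python
# getattr avoids an AttributeError on non-HF models that lack _tied_weights_keys
tied_weights_keys = getattr(model, "_tied_weights_keys", None)
weights_tied = (
    model_config.get("tie_word_embeddings", False)
    # some models may be misconfigured to have weight tying enabled but not define tied weight keys
    and tied_weights_keys is not None
)
```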
src/peft/utils/other.py (outdated)

```python
# Check if any of the target layers correspond to tied weights in the model
# Instead of guessing layer names, compare against actual tied weight keys
# Extract module names from tied weights keys (remove the weight attribute name)
tied_module_names = {".".join(key.split(".")[:-1]) for key in model._tied_weights_keys}
```
I'd say this is simpler:

```diff
-tied_module_names = {".".join(key.split(".")[:-1]) for key in model._tied_weights_keys}
+tied_module_names = {key.rpartition(".")[0] for key in model._tied_weights_keys}
```

I saw that the existing code does the same thing as you did here, but let's still try to improve :) (feel free to adjust the existing code below too).
src/peft/utils/other.py (outdated)

```python
        break

# Find which target layers are in the tied weights (including the embedding source)
for target_layer in target_layers:
```
I'd rename target_layer to target_layer_name to make it clear that it's the name, not the module itself.
src/peft/utils/other.py (outdated)

```python
has_both_layers = True
# Check if all tied layers have the same indices
first_indices = target_layers[tied_layer_keys[0]]
indices_match = all(target_layers[key] == first_indices for key in tied_layer_keys[1:])
```
I don't think we need both indices_match and indices_mismatch, it's a bit redundant. I think it's easiest to eliminate the former.
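A sketch of what that could look like, keeping only indices_mismatch (names taken from the diff above):

```python
first_indices = target_layers[tied_layer_keys[0]]
# True as soon as any tied target layer has indices that differ from the first one
indices_mismatch = any(target_layers[key] != first_indices for key in tied_layer_keys[1:])
```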
src/peft/utils/other.py (outdated)

```python
for name, module in model.named_modules():
    if module is embedding_module:
        # Get just the last part of the name for matching with target_layers
        embedding_name = name.split(".")[-1]
```
Although the logic in this loop is fine, it can be a bit confusing: what would it mean if the embedding_module is not found? This should never happen, right? So I'm wondering if we can do something like:

```python
embedding_name = next(n.split(".")[-1] for n, m in model.named_modules() if m is embedding_module)
```

This would raise an error if embedding_module is not found instead of leaving embedding_name = None. What's your opinion?
src/peft/utils/other.py (outdated)

```python
if weights_tied and ensure_weight_tying and has_both_layers and indices_mismatch:
    # Build more generic error message showing the conflicting layers
    tied_layers_info = ", ".join([f"{key}: {target_layers[key]}" for key in tied_layer_keys])
    raise ValueError(
```
Can we not raise this error immediately after indices_mismatch was determined? The earlier we can raise, the better. It should also make the check simpler, as we only need to check for if indices_mismatch.
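For instance, something like the following could raise right where the mismatch is detected (a sketch; the exact condition depends on how the surrounding checks end up being simplified):

```python
if indices_mismatch and ensure_weight_tying:
    tied_layers_info = ", ".join(f"{key}: {target_layers[key]}" for key in tied_layer_keys)
    raise ValueError(
        f"Cannot ensure weight tying when different token indices are specified: {tied_layers_info}"
    )
```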
src/peft/utils/other.py (outdated)

```python
# Since indices match here, indices_mismatch=False, so this simplifies to: we apply tying
# Skip all tied modules except the embedding (first one in tied_layer_keys)
# The embedding is typically first, but to be safe, skip modules in _tied_weights_keys
for key in tied_layer_keys:
```
I'm wondering if we cannot simply take the intersection between the two:

```python
layers_to_skip = set(tied_layer_keys) & tied_module_names
```

This approach would fail if we have a substring match but not a full string match, which is what you cover with tied_module.endswith(key). However, I don't see what would need to happen for a substring-only match, and AFAICT, the tests also never reach that point. Could you please explain?
src/peft/utils/other.py (outdated)

```python
    and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
):
    # the embedding layer is modified and we want weight tying.
    and not (not ensure_weight_tying and has_both_layers and indices_mismatch)
```
This conditional is a bit hard to read IMO, let's try to simplify. So for `and not (not ensure_weight_tying ...`, let's move it out of the parenthesis, i.e. it becomes `and ensure_weight_tying`. As for `indices_mismatch`, this can only ever be True if `has_both_layers` is also True, right? So we don't really need to check both.
src/peft/utils/other.py (outdated)

```python

if len(target_layers) > 1 and weights_tied and model._tied_weights_keys:
    # Check if any of the target layers correspond to tied weights in the model
    # Instead of guessing layer names, compare against actual tied weight keys
```
This comment can be removed IMO.
tests/test_trainable_tokens.py (outdated)

```python
assert lm_head_adapter.token_indices["default"] == [1, 2]

def test_weight_tying_bc_same_indices_applied(self, model_weight_tied):
    """Backwards compatibility: same indices should have weight tying even when ensure_weight_tying=False"""
```
This is not really for BC, is it? I think this is just the general expected behavior. The BC part is only for cases where the behavior might not be what the user expects, but we cannot change it now because it would be backwards incompatible.
Implement ensure_weight_tying for trainable_token_indices

Summary

This PR implements consistent weight tying behavior for `trainable_token_indices` as specified in issue #2864. It extends the `ensure_weight_tying` parameter (introduced in PR #2803) to work with `trainable_token_indices`, providing users explicit control over weight tying between embeddings and LM head.

Fixes #2864 (trainable_token_indices portion)

Problem Statement

Background

PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when `tie_word_embeddings=True`). The `ensure_weight_tying` parameter was introduced in PR #2803 to give users explicit control over this behavior for `modules_to_save`. However, the same control was missing for `trainable_token_indices`.

The Issue

Issue #2864 identified that the weight tying behavior for `trainable_token_indices` was not consistent across different scenarios. Specifically, there were four cases that needed to be implemented (see "Four Cases Implemented" below).

Solution Approach

Implementation Strategy:
Changes Made

1. Updated Configuration Documentation

File: `src/peft/tuners/lora/config.py`

Updated the `ensure_weight_tying` parameter docstring to clarify that it now applies to both `modules_to_save` and `trainable_token_indices`, making the documentation consistent with the implementation.

2. Implemented Weight Tying Logic

File: `src/peft/utils/other.py`

Added comprehensive logic within the existing `trainable_token_indices` handling block.

Key Components:
- `ensure_weight_tying=False`

Four Cases Implemented:
Case 1 - Warning for Untied Models: `weights_tied=False` + `ensure_weight_tying=True`

Case 2 - Error for Contradictory Configuration: `weights_tied=True` + `ensure_weight_tying=True` + different indices

Case 3 - Backwards Compatibility: `weights_tied=True` + `ensure_weight_tying=False` + different indices

Case 4 - Apply Tying
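For context, a minimal usage sketch of the behavior described above (model choice, target module, and indices are illustrative; gpt2 is used because it ties lm_head to the input embedding):

```python
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model

# gpt2 is configured with tie_word_embeddings=True, so lm_head shares weights with wte
model = AutoModelForCausalLM.from_pretrained("gpt2")

config = LoraConfig(
    target_modules=["c_attn"],          # illustrative LoRA target for gpt2
    trainable_token_indices=[1, 2, 3],  # list format: applied to the input embedding
    ensure_weight_tying=True,           # opt in to explicit handling of the tied lm_head
)
peft_model = get_peft_model(model, config)
```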
3. Comprehensive Test Suite

File: `tests/test_trainable_tokens.py`

Added 7 new test methods covering all scenarios.

Test Coverage:
- `test_ensure_weight_tying_warns_when_model_not_tied_list_format`: Verifies warning for list format
- `test_ensure_weight_tying_warns_when_model_not_tied_dict_format`: Verifies warning for dict format
- `test_weight_tying_bc_different_indices_treated_separately`: Verifies backwards compatibility
- `test_ensure_weight_tying_errors_with_different_indices`: Verifies error for contradictory config
- `test_ensure_weight_tying_applied_with_same_indices`: Verifies tying with same indices
- `test_weight_tying_bc_same_indices_applied`: Verifies BC for same indices
- `test_ensure_weight_tying_with_single_layer`: Verifies list format tying

Testing Results
New Tests

All 7 new tests pass successfully:
- `test_ensure_weight_tying_warns_when_model_not_tied_list_format`
- `test_ensure_weight_tying_warns_when_model_not_tied_dict_format`
- `test_weight_tying_bc_different_indices_treated_separately`
- `test_ensure_weight_tying_errors_with_different_indices`
- `test_ensure_weight_tying_applied_with_same_indices`
- `test_weight_tying_bc_same_indices_applied`
- `test_ensure_weight_tying_with_single_layer`

Backwards Compatibility
This implementation maintains full backwards compatibility:

✅ Default Behavior Unchanged: `ensure_weight_tying` defaults to `False`, preserving existing behavior
✅ No Breaking Changes: Existing code continues to work without modification
✅ Opt-in Enhancement: Users must explicitly set `ensure_weight_tying=True` to use new features
✅ BC Mode Preserved: When `ensure_weight_tying=False`, existing automatic tying still works for compatible configurations

Screenshots
Checklist
cc: @BenjaminBossan