Add forward compat. for tied_weights_keys dicts #2902
Conversation
In the future, `_tied_weights_keys` will be mappings (mapping destination as key and source as value). This also means that the semantics of `_tied_weights_keys` will change from "the keys in this list are tied to the input embedding" to "this mapping defines tying between layers of any type". To this end, we'll limit the scope and provide methods to retrieve input embedding ties. If the need arises to retrieve more complicated mappings, we can do so at a later point.
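For illustration, the two formats side by side (a minimal sketch; the parameter names below are the typical LM-head/embedding pair and are only an example, not taken from this PR):

```python
# transformers < v5: a list of parameter names that are assumed to be tied to the input embedding
_tied_weights_keys = ["lm_head.weight"]

# transformers >= v5 (planned): a mapping of destination -> source, which can describe
# tying between layers of any type, not just ties to the input embedding
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
```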
BenjaminBossan left a comment
Thanks for handling the forward compatibility of `_tied_weights_keys`. It's unfortunate that it's a more complicated issue.
I tried to understand the comment but struggled a bit. Probably you have some examples in mind where these extra steps are necessary. Would it make sense to write a small suite of unit tests for `_get_module_names_tied_with_embedding` with these edge cases? This would also be helpful as v5 is not released yet, so the part of the code dealing with dicts is not covered yet.
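For illustration, a minimal pytest-style sketch of the kind of test being asked for; the tiny model, its attribute names, and the import path are assumptions based on the diff excerpts below (`src/peft/utils/other.py`), and the expected result assumes the helper behaves as shown there:

```python
import torch.nn as nn

from peft.utils.other import _get_module_names_tied_with_embedding  # path assumed from the diff


class TinyTiedLM(nn.Module):
    # transformers >= v5 style: mapping of destination -> source
    _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

    def __init__(self, vocab_size=10, hidden_size=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # actually tie the weights

    def get_input_embeddings(self):
        return self.embed_tokens


def test_dict_tied_weights_keys_resolves_to_lm_head():
    model = TinyTiedLM()
    assert _get_module_names_tied_with_embedding(model) == ["lm_head"]
```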
src/peft/utils/other.py
Outdated
```python
def _get_module_names_tied_with_embedding(model):
    """
    Get the list of the fully qualified names of the modules that are tied to the input embeddings. In case of a
    source-target-mapping `_tied_weights_keys`, it will attempt to identify the input embedding weights.
```
I find the second sentence confusing ("In case of"): it makes it sound like, if it's a dict, the input embedding will be returned. What it means is that it's still the target that is returned, but the function tries to identify only those targets that are actually tied to the input embedding.
Also, let's mention that this function is for compatibility with >= v5. Let's add a TODO for when <v5 is no longer supported.
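One way the docstring could be reworded along those lines (a sketch only, not necessarily the wording that landed in the PR):

```python
def _get_module_names_tied_with_embedding(model) -> list[str]:
    """
    Return the fully qualified names of the modules whose weights are tied to the input embedding.

    If `_tied_weights_keys` is a destination -> source mapping (transformers >= v5), only those
    destinations whose source is the input embedding are returned; with the list form (< v5),
    all entries are assumed to be tied to the input embedding.

    TODO: simplify once transformers < v5 is no longer supported.
    """
```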
src/peft/utils/other.py
Outdated
```python
    return attention_mask


def _get_module_names_tied_with_embedding(model):
```
Let's add `-> list[str]`.
src/peft/utils/other.py
Outdated
| """ | ||
| tied_weights = [] | ||
|
|
||
| for name, module in model.named_modules(): |
Just for my understanding: We can't just check on the model itself because there can be submodules with their own weight tying?
src/peft/utils/other.py
Outdated
```python
        input_embedding_params = set(module.get_input_embeddings().parameters())
        candidates = [n for n, p in module.named_parameters(remove_duplicate=False) if p in input_embedding_params]

        tied_weights.extend(k for k, v in module._tied_weights_keys.items() if k in candidates)
```
Do we need to do `f"{name}.{k}" if name else k` for `k` here too? Also, should we check `if v in candidates` instead of `k`?
That's a good question. I assumed so because the previous code walked through submodules recursively. Skimming transformers models, I haven't found any model that recursively defines tied weights keys.
So, if this isn't the case, we may be able to get away with checking `get_input_embeddings` and `_tied_weights_keys` at the top level. Not sure if we can align those values consistently though (`_tied_weights_keys` will always return base-model-relative values while `get_input_embeddings` resolution will be PEFT-model-based unless we do it the way it is currently solved).
As discussed internally, let's keep the check on sub-modules, e.g. for the case where the transformers model is wrapped inside another `nn.Module`.
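For context, this is the kind of wrapped setup the comment refers to (a hypothetical wrapper, not a class from the PR):

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """The transformers model is not the top-level module here."""

    def __init__(self, transformers_model: nn.Module):
        super().__init__()
        # e.g. an AutoModelForCausalLM instance; `_tied_weights_keys` and
        # `get_input_embeddings` live on this sub-module, not on the wrapper
        self.lm = transformers_model

    def forward(self, *args, **kwargs):
        return self.lm(*args, **kwargs)
```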
- don't loop over all modules, assume that the module tying is defined on the top level (there's no precedent for the opposite yet)
- make sure that only the base model is considered to prevent duplicates due to `getattr` forwarding by PEFT

Also added a few more tests.
I've taken some time to think about this and I want to propose a new variant which doesn't loop over all the modules but unpacks the model to look at the base model's tied-weights mapping directly. I've also added a few more unit tests.
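Roughly, the proposed variant looks like this (a sketch assembled from the discussion and the diff excerpts below, not necessarily the exact code that was merged; the `get_base_model` unwrapping follows PEFT's `PeftModel` API):

```python
def _get_module_names_tied_with_embedding(model) -> list[str]:
    tied_weights = []

    # unwrap the PEFT wrapper so that `getattr` forwarding does not produce duplicates
    if hasattr(model, "get_base_model"):
        model = model.get_base_model()

    if getattr(model, "_tied_weights_keys", None) is None:
        return []

    if isinstance(model._tied_weights_keys, dict):
        # transformers >= v5: mapping of destination -> source
        input_embedding_params = set(model.get_input_embeddings().parameters())
        candidates = [n for n, p in model.named_parameters(remove_duplicate=False) if p in input_embedding_params]
        tied_weights.extend(k for k, v in model._tied_weights_keys.items() if v in candidates)
    else:
        # TODO remove this branch when transformers < v5 is no longer supported
        tied_weights.extend(model._tied_weights_keys)

    # parameter names -> module names, e.g. "lm_head.weight" -> "lm_head"
    return sorted({name.rpartition(".")[0] for name in tied_weights})
```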
BenjaminBossan left a comment
Thanks for the updates. I still have a few comments. Most are just minor things, but there is one about the logic in checking embedding candidates where I'm confused.
src/peft/utils/other.py
Outdated
```python
    from peft import PeftModel
    from peft.tuners.tuners_utils import BaseTuner

    tied_weights = []

    if isinstance(model, PeftModel):
        model = model.base_model

    if isinstance(model, BaseTuner):
        model = model.model
```
How about:

```python
if hasattr(model, "get_base_model"):
    model = model.get_base_model()
```

That's more succinct and requires no local imports. The false positive rate on this should be pretty small, and users with custom models can implement that API.
src/peft/utils/other.py
Outdated
```python
    # the reason why we don't compute `candidates` once is that there might be a few levels of nesting
    # so that the keys have various prefixes (e.g., `model.`). eventually we'll find the model that
    # defines both input embedding and weight tying mapping, then the keys will match.
```
This comment is no longer needed, right?
src/peft/utils/other.py
Outdated
```python
        input_embedding_params = set(model.get_input_embeddings().parameters())
        candidates = [n for n, p in model.named_parameters(remove_duplicate=False) if p in input_embedding_params]

        tied_weights.extend(k for k, v in model._tied_weights_keys.items() if k in candidates)
```
So `candidates` contains the names of the parameters of the source (the embedding). `_tied_weights_keys` is a mapping from target to source. Therefore, in the loop, `k` is the target weight name and `v` is the source (embedding) weight name. Shouldn't we thus check `if v in candidates`?
Technically it is irrelevant if tying works. If tying doesn't work, it is better to use `v`, I think; going to change this.
src/peft/utils/other.py
Outdated
```python
        tied_weights.extend(k for k, v in model._tied_weights_keys.items() if k in candidates)
    elif model._tied_weights_keys is not None:
        # TODO remove this when transformers <v5 is no longer supported
        tied_weights.extend(k for k in model._tied_weights_keys)
```
```diff
-        tied_weights.extend(k for k in model._tied_weights_keys)
+        tied_weights.extend(model._tied_weights_keys)
```
src/peft/utils/other.py
Outdated
```python
        # TODO remove this when transformers <v5 is no longer supported
        tied_weights.extend(k for k in model._tied_weights_keys)

    return list({".".join(name.split(".")[:-1]) for name in tied_weights})
```
How about `sorted` instead of `list` for a predictable output?
tests/test_other.py
Outdated
```python
        assert no_split_modules == {"CLIPEncoderLayer", "LlamaDecoderLayer"}


class TestGetTiedWithEmbedding:
```
```diff
-class TestGetTiedWithEmbedding:
+class TestGetModulesTiedWithEmbedding:
```
tests/test_other.py
Outdated
| {".".join(k.split(".")[:-1]) for k in self.model_tied_weights_mapping[model_id].keys()} | ||
| ) | ||
|
|
||
| if tied_weights_type == "linear": |
Let's add a comment that this is for transformers <v5 and below for >= v5.
tests/test_other.py
Outdated
```python
    def test_get_modules_tied_to_embedding(self, model_id, tied_weights_type):
        model, expected = self.get_model(model_id, tied_weights_type)

        modules = _get_module_names_tied_with_embedding(model)
```
You could do: `modules = model._get_module_names_tied_with_embedding()` and then there is no need to import `_get_module_names_tied_with_embedding`. It's also closer to what is executed in practice.
- PEFT modifications to the model (e.g., adding `(.*)?.base_layer.`) were not considered but should be handled now
- Future testing (transformers >5) needs to switch the roles in the tests (dict is provided, list must be simulated)
- local imports are replaced by using `hasattr` calls to check for specific attributes instead
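For context, the renaming the first bullet refers to looks roughly like this (a sketch with a hypothetical tiny model; requires `peft` installed, and the exact prefixes depend on the adapter and config):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class Tiny(nn.Module):
    def __init__(self, vocab_size=10, hidden_size=8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the weights


peft_model = get_peft_model(Tiny(), LoraConfig(target_modules=["lm_head"]))
# once LoRA wraps `lm_head`, its original weight is only reachable under `.base_layer.`,
# e.g. something like "base_model.model.lm_head.base_layer.weight" instead of "lm_head.weight"
print([n for n, _ in peft_model.named_parameters() if "lm_head" in n])
```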
BenjaminBossan left a comment
Thanks, what a mess. Just an issue with a debug print and a nit, otherwise LGTM.
src/peft/utils/other.py
Outdated
```python
        model = model.model

    if not hasattr(model, "_tied_weights_keys"):
        print("NO TIED WEIGHTS KEYS")
```
| print("NO TIED WEIGHTS KEYS") | |
| print("NO TIED WEIGHTS KEYS!!!!!!!") |
src/peft/utils/other.py
Outdated
```python
        # TODO remove this when transformers <v5 is no longer supported
        tied_weights.extend(model._tied_weights_keys)

    return sorted({".".join(name.split(".")[:-1]) for name in tied_weights})
```
Nit:

```diff
-    return sorted({".".join(name.split(".")[:-1]) for name in tied_weights})
+    return sorted({name.rpartition(".")[0] for name in tied_weights})
```
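Both expressions map a parameter name to the name of its parent module; a quick check of the equivalence (for a name without any dot, both return an empty string):

```python
name = "transformer.lm_head.weight"
assert ".".join(name.split(".")[:-1]) == name.rpartition(".")[0] == "transformer.lm_head"
```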
BenjaminBossan left a comment
Thanks, LGTM. Can be merged once CI is green.