
Conversation

@githubnemo
Collaborator

In the future, `_tied_weights_keys` will be mappings (mapping destination as key and source as value). This will also mean that the semantics of `_tied_weights_keys` will change from "the keys in this list are tied to the input embedding" to "this mapping defines tying between layers of any type". To this end we'll limit the scope and provide methods to retrieve input embedding ties.

If the need arises to retrieve more complicated mappings, we can do so at a later point.
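
For illustration, the shape change looks roughly like this (the attribute values below are made up, not taken from a particular model):

# transformers < v5: a list; the keys in it are implicitly tied to the input embedding
_tied_weights_keys = ["lm_head.weight"]

# transformers >= v5: a mapping with destination (target) as key and source as value, which can
# describe tying between layers of any type, not just the input embedding
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}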
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for handling the forward compatibility of _tied_weights_keys. It's unfortunate that it's a more complicated issue.

I tried to understand the comment but struggled a bit. Probably you have some examples in mind where these extra steps are necessary. Would it make sense to write a small suite of unit tests for _get_module_names_tied_with_embedding with these edge cases? This would also be helpful as v5 is not released yet, so the part of the code dealing with dicts is not covered yet.
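
A rough sketch of what one such test could look like (the toy model, the import path of the helper, and the expected result are assumptions for illustration, not the actual test code):

import torch.nn as nn

# assumed import location of the helper under test
from peft.utils.other import _get_module_names_tied_with_embedding


class TinyTiedModel(nn.Module):
    # toy stand-in for a transformers model using the >= v5 dict form (target -> source)
    _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)
        self.lm_head = nn.Linear(8, 10, bias=False)
        self.lm_head.weight = self.embed_tokens.weight  # tie the output head to the input embedding

    def get_input_embeddings(self):
        return self.embed_tokens


def test_dict_tied_weights_keys_resolves_to_lm_head():
    model = TinyTiedModel()
    assert _get_module_names_tied_with_embedding(model) == ["lm_head"]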

def _get_module_names_tied_with_embedding(model):
    """
    Get the list of the fully qualified names of the modules that are tied to the input embeddings. In case of a
    source-target-mapping `_tied_weights_keys`, it will attempt to identify the input embedding weights.
Member

I find the second sentence confusing ("In case of"): it makes it sound like, if it's a dict, the input embedding will be returned. What it actually means is that it's still the target that is returned, but the function tries to identify only those targets that are actually tied to the input embedding.

Also, let's mention that this function is for compatibility with >= v5. Let's add a TODO for when <v5 is no longer supported.

    return attention_mask


def _get_module_names_tied_with_embedding(model):
Member

Let's add `-> list[str]`.

"""
tied_weights = []

for name, module in model.named_modules():
Member

Just for my understanding: We can't just check on the model itself because there can be submodules with their own weight tying?

input_embedding_params = set(module.get_input_embeddings().parameters())
candidates = [n for n, p in module.named_parameters(remove_duplicate=False) if p in input_embedding_params]

tied_weights.extend(k for k, v in module._tied_weights_keys.items() if k in candidates)
Member

Do we need to do `f"{name}.{k}" if name else k` for `k` here too? Also, should we check `if v in candidates` instead of `k`?

Collaborator Author
@githubnemo githubnemo Nov 7, 2025

That's a good question. I assumed so because the previous code walked through submodules recursively. Skimming transformers models, I haven't found any model that recursively defines tied weights keys.

So, if this isn't the case, we may be able to get away with checking `get_input_embeddings` and `_tied_weights_keys` at the top level. Not sure if we can align those values consistently though (`_tied_weights_keys` will always return base-model-relative values while `get_input_embeddings` resolution will be PEFT-model-based unless we do it the way it is currently solved).

Member

As discussed internally, let's keep the check on sub-modules, e.g. for the case where the transformers model is wrapped inside another nn.Module.
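
For context, the wrapped case mentioned here looks roughly like this (the wrapper class is made up; the assumption is that the inner causal LM class defines `_tied_weights_keys`, as transformers LM head classes typically do):

import torch.nn as nn
from transformers import AutoModelForCausalLM


class WrappedLM(nn.Module):
    # custom wrapper: the weight tying information lives on the inner transformers model,
    # not on the wrapper itself, so a purely top-level check would miss it
    def __init__(self, model_id):
        super().__init__()
        self.lm = AutoModelForCausalLM.from_pretrained(model_id)
        self.classifier = nn.Linear(self.lm.config.hidden_size, 2)


wrapper = WrappedLM("gpt2")
assert not hasattr(wrapper, "_tied_weights_keys")  # the wrapper itself carries no tying info
assert hasattr(wrapper.lm, "_tied_weights_keys")   # the inner model does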

nemo added 2 commits November 10, 2025 18:17
- don't loop over all modules, assume that the module tying is
  defined on the top-level (there's no precedent for the opposite yet)
- make sure that only the base model is considered to prevent
  duplicates due to `getattr` forwarding by PEFT

Also added a few more tests.
@githubnemo
Collaborator Author

I've taken some time to think about this and I want to propose a new variant which doesn't loop over all the modules but unpacks the model to look at the _tied_weights_keys attribute. This solves potential duplicates and prefix issues at the cost of function-local imports to avoid circular imports. Let me know what you think.

I've also added a few more unit tests.

Member

@BenjaminBossan BenjaminBossan left a comment


Thanks for the updates. I still have a few comments. Most are just minor things, but there is one about the logic in checking embedding candidates where I'm confused.

Comment on lines 1573 to 1582
    from peft import PeftModel
    from peft.tuners.tuners_utils import BaseTuner

    tied_weights = []

    if isinstance(model, PeftModel):
        model = model.base_model

    if isinstance(model, BaseTuner):
        model = model.model
Member

How about:

if hasattr(model, "get_base_model"):
    model = model.get_base_model()

That's more succinct and requires no local imports. The false positive rate on this should be pretty small and users with custom models can implement that API.

Comment on lines 1599 to 1601
# the reason why we don't compute `candidates` once is that there might be a few levels of nesting
# so that the keys have various prefixes (e.g., `model.`). eventually we'll find the model that
# defines both input embedding and weight tying mapping, then the keys will match.
Member

This comment is no longer needed, right?

input_embedding_params = set(model.get_input_embeddings().parameters())
candidates = [n for n, p in model.named_parameters(remove_duplicate=False) if p in input_embedding_params]

tied_weights.extend(k for k, v in model._tied_weights_keys.items() if k in candidates)
Member

So `candidates` contains the names of the parameters of the source (the embedding). `_tied_weights_keys` is a mapping from target to source. Therefore, in the loop, `k` is the target weight name and `v` is the source (embedding) weight name. Shouldn't we thus check `if v in candidates`?

Collaborator Author

Technically, it is irrelevant as long as tying works. If tying doesn't work, it is better to use `v`, I think; going to change this.
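
To spell out the direction with a made-up example:

# >= v5 style: destination (target) as key, source as value
tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

# parameter names that share storage with the input embedding; before tying has been
# applied, only the source name shows up here
candidates = ["model.embed_tokens.weight"]

# checking the value (source) still finds the intended tie, checking the key (target) would not
assert [k for k, v in tied_weights_keys.items() if v in candidates] == ["lm_head.weight"]
assert [k for k, v in tied_weights_keys.items() if k in candidates] == []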

        tied_weights.extend(k for k, v in model._tied_weights_keys.items() if k in candidates)
    elif model._tied_weights_keys is not None:
        # TODO remove this when transformers <v5 is no longer supported
        tied_weights.extend(k for k in model._tied_weights_keys)
Member

Suggested change
tied_weights.extend(k for k in model._tied_weights_keys)
tied_weights.extend(model._tied_weights_keys)

        # TODO remove this when transformers <v5 is no longer supported
        tied_weights.extend(k for k in model._tied_weights_keys)

    return list({".".join(name.split(".")[:-1]) for name in tied_weights})
Member

How about sorted instead of list for a predictable output?
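
A small illustration of the difference (module names are made up):

modules = {"lm_head", "embed_out"}
print(list(modules))    # iteration order of a set is not stable across runs (hash randomization)
print(sorted(modules))  # always ['embed_out', 'lm_head']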

assert no_split_modules == {"CLIPEncoderLayer", "LlamaDecoderLayer"}


class TestGetTiedWithEmbedding:
Member

Suggested change
class TestGetTiedWithEmbedding:
class TestGetModulesTiedWithEmbedding:

{".".join(k.split(".")[:-1]) for k in self.model_tied_weights_mapping[model_id].keys()}
)

if tied_weights_type == "linear":
Member

Let's add a comment that this is for transformers <v5 and the one below for >= v5.

    def test_get_modules_tied_to_embedding(self, model_id, tied_weights_type):
        model, expected = self.get_model(model_id, tied_weights_type)

        modules = _get_module_names_tied_with_embedding(model)
Member

You could do: `modules = model._get_module_names_tied_with_embedding()` and then there is no need to import `_get_module_names_tied_with_embedding`. It's also closer to what is executed in practice.

- PEFT modifications to the model (e.g., adding `(.*)?.base_layer.`) were not considered but
  should be handled now
- Future testing (transformers >5) needs to switch the roles in the tests (dict is provided, list must be simulated)
- local imports are replaced by using `hasattr` calls to check for specific attributes instead
Member

@BenjaminBossan BenjaminBossan left a comment


Thanks, what a mess. Just an issue with a debug print and a nit, otherwise LGTM.

        model = model.model

    if not hasattr(model, "_tied_weights_keys"):
        print("NO TIED WEIGHTS KEYS")
Member

Suggested change
print("NO TIED WEIGHTS KEYS")
print("NO TIED WEIGHTS KEYS!!!!!!!")

        # TODO remove this when transformers <v5 is no longer supported
        tied_weights.extend(model._tied_weights_keys)

    return sorted({".".join(name.split(".")[:-1]) for name in tied_weights})
Member

Nit:

Suggested change
return sorted({".".join(name.split(".")[:-1]) for name in tied_weights})
return sorted({name.rpartition(".")[0] for name in tied_weights})
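
Both expressions drop the last dot-separated component, turning a parameter name into its module name; a quick check with made-up names:

names = ["lm_head.weight", "model.embed_tokens.weight"]

assert [".".join(n.split(".")[:-1]) for n in names] == ["lm_head", "model.embed_tokens"]
assert [n.rpartition(".")[0] for n in names] == ["lm_head", "model.embed_tokens"]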

Member

@BenjaminBossan BenjaminBossan left a comment


Thanks, LGTM. Can be merged once CI is green.

@githubnemo githubnemo merged commit 52e8659 into huggingface:main Nov 12, 2025
12 of 13 checks passed