
Why does peft.utils.other.fsdp_auto_wrap_policy not wrap modules that do not require grad? #2640

@Changlin-Lee

Description


In https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L977, the policy is defined as follows:

# Quoted from peft/utils/other.py. The function relies on names defined elsewhere
# in that file (functools, os, FullyShardedDataParallelPlugin, _get_no_split_modules),
# which are omitted from this excerpt.
def fsdp_auto_wrap_policy(model):
    if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
        get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
    else:
        from accelerate.utils.dataclasses import get_module_class_from_name
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model))
    transformer_cls_names_to_wrap = os.environ.get(
        "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
    ).split(",")
    transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
    for layer_class in transformer_cls_names_to_wrap:
        if len(layer_class) == 0:
            continue
        transformer_cls = get_module_class_from_name(model, layer_class)
        if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        else:
            transformer_cls_to_wrap.add(transformer_cls)

    def lambda_policy_fn(module):
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_cls_to_wrap,
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
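
For context, in training scripts this policy is typically attached to accelerate's FSDP plugin before accelerator.prepare(). A minimal usage sketch (the PEFT-wrapped model and the accelerate FSDP launch configuration are assumed here, they are not part of the issue):

from accelerate import Accelerator
from peft.utils.other import fsdp_auto_wrap_policy

accelerator = Accelerator()  # launched with an FSDP config, e.g. via `accelerate launch`

# `model` is a PEFT (e.g. LoRA) wrapped model created elsewhere.
if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

model = accelerator.prepare(model)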

The fsdp_auto_wrap_policy uses a lambda_policy_fn that only wraps leaf modules whose weight requires grad, so frozen modules are not wrapped by it.
But in regular LoRA training, the parameters of the original (base) network do not require grad.
That may cause every GPU to keep a full copy of the base network, even with FSDP FULL_SHARD (see the sketch below for which modules the predicate matches).
Why is the policy designed this way?
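
A small sketch of which leaf modules the lambda predicate actually matches on a LoRA model (the base model and LoRA config here are illustrative assumptions, not from the original issue):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

def lambda_policy_fn(module):
    # Same predicate as in fsdp_auto_wrap_policy: a leaf module with a trainable weight.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

matched = [name for name, m in model.named_modules() if lambda_policy_fn(m)]
print(matched)
# Only the lora_A / lora_B linears match; the frozen base leaves do not,
# which is exactly the behavior the question is about.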
