In https://github.com/huggingface/peft/blob/main/src/peft/utils/other.py#L977:

```python
def fsdp_auto_wrap_policy(model):
    if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
        get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
    else:
        from accelerate.utils.dataclasses import get_module_class_from_name

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model))
    transformer_cls_names_to_wrap = os.environ.get(
        "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
    ).split(",")
    transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
    for layer_class in transformer_cls_names_to_wrap:
        if len(layer_class) == 0:
            continue
        transformer_cls = get_module_class_from_name(model, layer_class)
        if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        else:
            transformer_cls_to_wrap.add(transformer_cls)

    def lambda_policy_fn(module):
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_cls_to_wrap,
    )

    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
```
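
For context, this is roughly how the returned policy gets wired into Accelerate (a minimal sketch based on the PEFT FSDP examples; the model, LoRA config, and launch setup here are illustrative placeholders, not code from the repo):

```python
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model
from peft.utils.other import fsdp_auto_wrap_policy
from transformers import AutoModelForCausalLM

# Small model used only for illustration; any LoRA-wrapped transformers model works the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

# Assumes the script was launched with an FSDP accelerate config (e.g. `accelerate launch`).
accelerator = Accelerator()

# Hand the PEFT auto-wrap policy to the FSDP plugin before prepare().
if getattr(accelerator.state, "fsdp_plugin", None) is not None:
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

model = accelerator.prepare(model)
```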
The fsdp_auto_wrap_policy uses a lambda_policy_fn that does not wrap modules whose weights do not require grad.
But in regular LoRA training, the parameters of the original network do not require grad.
That may cause every GPU to still keep a full copy of the base network, even with FSDP FULL_SHARD.
Why was the policy designed this way?
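
A minimal standalone sketch of the behavior being asked about, with plain torch.nn modules standing in for the real layers (the names are illustrative):

```python
import torch.nn as nn

def lambda_policy_fn(module):
    # Same predicate as in fsdp_auto_wrap_policy above: a leaf module is only
    # matched if it has a weight that requires grad.
    return (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

frozen_base = nn.Linear(16, 16)          # stands in for a frozen base-model projection
frozen_base.weight.requires_grad_(False)
lora_a = nn.Linear(16, 4, bias=False)    # stands in for a trainable lora_A layer

print(lambda_policy_fn(frozen_base))  # False: not matched by the lambda policy
print(lambda_policy_fn(lora_a))       # True: wrapped as its own FSDP unit
```

(From the quoted code, the frozen base layers can still be matched by the transformer_auto_wrap_policy branch of the _or_policy, since the model's no-split module classes are added to transformer_cls_to_wrap; the question is about why the lambda branch is restricted to trainable leaves.)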