49 changes: 37 additions & 12 deletions examples/oft_dreambooth/oft_dreambooth_inference.ipynb

Large diffs are not rendered by default.

69 changes: 42 additions & 27 deletions examples/oft_dreambooth/train_dreambooth.py
@@ -180,8 +180,8 @@ def parse_args(input_args=None):

# oft args
parser.add_argument("--use_oft", action="store_true", help="Whether to use OFT for parameter efficient tuning")
parser.add_argument("--oft_r", type=int, default=8, help="OFT rank, only used if use_oft is True")
parser.add_argument("--oft_alpha", type=int, default=32, help="OFT alpha, only used if use_oft is True")
parser.add_argument("--oft_r", type=int, default=0, help="OFT rank, only used if use_oft is True")
parser.add_argument("--oft_block_size", type=int, default=32, help="OFT block size, only used if use_oft is True")
Contributor Author: OFT doesn't have `alpha`; it's `oft_block_size`. Also align the default value of `r` with `OFTConfig`.
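For reference, a minimal sketch of how the renamed flags map onto `OFTConfig` (an assumption-laden illustration: it presumes a recent `peft` release where exactly one of `r` and `oft_block_size` may be non-zero, and the `target_modules` shown are placeholders, not the script's actual `UNET_TARGET_MODULES`):

```python
from peft import OFTConfig

# Sketch only: specify exactly one of `r` (number of blocks) and
# `oft_block_size` (size of each block); leave the other at 0 so peft
# can derive it from the layer dimensions.
config = OFTConfig(
    r=0,                # matches the new --oft_r default
    oft_block_size=32,  # matches the new --oft_block_size default
    module_dropout=0.0,
    target_modules=["to_q", "to_k", "to_v"],  # placeholder modules
    init_weights=True,
)
```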

parser.add_argument("--oft_dropout", type=float, default=0.0, help="OFT dropout, only used if use_oft is True")
parser.add_argument(
"--oft_use_coft", action="store_true", help="Using constrained OFT, only used if use_oft is True"
@@ -196,14 +196,14 @@ def parse_args(input_args=None):
parser.add_argument(
"--oft_text_encoder_r",
type=int,
-default=8,
+default=0,
help="OFT rank for text encoder, only used if `use_oft` and `train_text_encoder` are True",
)
parser.add_argument(
"--oft_text_encoder_alpha",
"--oft_text_encoder_block_size",
type=int,
default=32,
help="OFT alpha for text encoder, only used if `use_oft` and `train_text_encoder` are True",
help="OFT block size for text encoder, only used if `use_oft` and `train_text_encoder` are True",
)
parser.add_argument(
"--oft_text_encoder_dropout",
@@ -287,7 +287,7 @@ def parse_args(input_args=None):
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
help="Scale the learning rate by the number of accelerators, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
@@ -369,7 +369,7 @@ def parse_args(input_args=None):
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" 1.10.and an Nvidia Ampere GPU or Intel XPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
@@ -380,7 +380,7 @@ def parse_args(input_args=None):
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
" 1.10.and an Nvidia Ampere GPU or Intel XPU. Default to fp16 if a GPU/XPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -420,10 +420,12 @@ def b2mb(x):
# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
def __enter__(self):
+self.device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+self.device_module = getattr(torch, self.device_type, torch.cuda)
gc.collect()
-torch.cuda.empty_cache()
-torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-self.begin = torch.cuda.memory_allocated()
+self.device_module.empty_cache()
+self.device_module.reset_peak_memory_stats()  # reset the peak gauge to zero
+self.begin = self.device_module.memory_allocated()
self.process = psutil.Process()

self.cpu_begin = self.cpu_mem_used()
@@ -453,9 +455,9 @@ def __exit__(self, *exc):
self.peak_monitoring = False

gc.collect()
-torch.cuda.empty_cache()
-self.end = torch.cuda.memory_allocated()
-self.peak = torch.cuda.max_memory_allocated()
+self.device_module.empty_cache()
+self.end = self.device_module.memory_allocated()
+self.peak = self.device_module.max_memory_allocated()
self.used = b2mb(self.end - self.begin)
self.peaked = b2mb(self.peak - self.begin)
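As an aside, a standalone sketch of the device-agnostic pattern the new `__enter__`/`__exit__` code relies on (hedged: assumes PyTorch >= 2.6 for `torch.accelerator` and that an accelerator is actually present; the training-step placeholder is illustrative):

```python
import gc
import torch

# Resolve the active backend ("cuda", "xpu", ...) and its torch module;
# older PyTorch without torch.accelerator falls back to CUDA.
device_type = (
    torch.accelerator.current_accelerator().type
    if hasattr(torch, "accelerator")
    else "cuda"
)
device_module = getattr(torch, device_type, torch.cuda)

gc.collect()
device_module.empty_cache()
device_module.reset_peak_memory_stats()  # zero the peak gauge
begin = device_module.memory_allocated()
# ... run a training step here ...
peaked_mb = (device_module.max_memory_allocated() - begin) // 2**20
```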

@@ -569,7 +571,7 @@ def collate_fn(examples, with_prior_preservation=False):


class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
"A simple dataset to prepare the prompts to generate class images on multiple accelerators."

def __init__(self, prompt, num_samples):
self.prompt = prompt
@@ -636,7 +638,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))

if cur_class_images < args.num_class_images:
-torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+torch_dtype = torch.float16 if accelerator.device.type in ["cuda", "xpu"] else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
@@ -673,6 +675,8 @@ def main(args):
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
+elif torch.xpu.is_available():
+    torch.xpu.empty_cache()

# Handle the repository creation
if accelerator.is_main_process:
@@ -725,7 +729,7 @@ def main(args):
if args.use_oft:
config = OFTConfig(
r=args.oft_r,
-alpha=args.oft_alpha,
+oft_block_size=args.oft_block_size,
target_modules=UNET_TARGET_MODULES,
module_dropout=args.oft_dropout,
init_weights=True,
@@ -742,7 +746,7 @@ def main(args):
elif args.train_text_encoder and args.use_oft:
config = OFTConfig(
r=args.oft_text_encoder_r,
-alpha=args.oft_text_encoder_alpha,
+oft_block_size=args.oft_text_encoder_block_size,
target_modules=TEXT_ENCODER_TARGET_MODULES,
module_dropout=args.oft_text_encoder_dropout,
init_weights=True,
@@ -754,7 +758,9 @@ def main(args):
print(text_encoder)

if args.enable_xformers_memory_efficient_attention:
-if is_xformers_available():
+if accelerator.device.type == "xpu":
+    logger.warn("XPU doesn't support xformers yet; ignoring it.")
+elif is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
@@ -767,15 +773,15 @@ def main(args):

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-if args.allow_tf32:
+if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True

if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
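(Worked example with hypothetical values: a base learning rate of 5e-6 with 2 gradient accumulation steps, a train batch size of 4, and 2 processes scales to 5e-6 × 2 × 4 × 2 = 8e-5.)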

-# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB accelerators
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
@@ -1040,18 +1046,27 @@ def main(args):
)

del pipeline
-torch.cuda.empty_cache()
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+elif torch.xpu.is_available():
+    torch.xpu.empty_cache()

if global_step >= args.max_train_steps:
break
-# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
+
+# Printing the accelerator memory usage details such as allocated memory, peak memory, and total memory usage
if not args.no_tracemalloc:
accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
accelerator.print(
f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
f"{accelerator.device.type.upper()} Memory before entering the train : {b2mb(tracemalloc.begin)}"
)
accelerator.print(
f"{accelerator.device.type.upper()} Memory consumed at the end of the train (end-begin): {tracemalloc.used}"
)
accelerator.print(
f"{accelerator.device.type.upper()} Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}"
)
accelerator.print(
f"{accelerator.device.type.upper()} Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
)

accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")