diff --git a/examples/stable_diffusion/train_dreambooth.py b/examples/stable_diffusion/train_dreambooth.py
index 56406ed082..9af5fee335 100644
--- a/examples/stable_diffusion/train_dreambooth.py
+++ b/examples/stable_diffusion/train_dreambooth.py
@@ -325,7 +325,7 @@ def parse_args(input_args=None):
         "--scale_lr",
         action="store_true",
         default=False,
-        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+        help="Scale the learning rate by the number of accelerators, gradient accumulation steps, and batch size.",
     )
     parser.add_argument(
         "--lr_scheduler",
@@ -407,7 +407,7 @@ def parse_args(input_args=None):
         choices=["no", "fp16", "bf16"],
         help=(
             "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+            " 1.10 and an Nvidia Ampere GPU or Intel XPU. Defaults to the value of the accelerate config of the current system or the"
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
@@ -418,7 +418,7 @@ def parse_args(input_args=None):
         choices=["no", "fp32", "fp16", "bf16"],
         help=(
             "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+            " 1.10 and an Nvidia Ampere GPU or Intel XPU. Defaults to fp16 if an accelerator is available, else fp32."
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -585,9 +585,11 @@ def b2mb(x):
 class TorchTracemalloc:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = torch.cuda.memory_allocated()
+        self.device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+        self.device_module = getattr(torch, self.device_type, torch.cuda)
+        self.device_module.empty_cache()
+        self.device_module.reset_peak_memory_stats()  # reset the peak gauge to zero
+        self.begin = self.device_module.memory_allocated()
         self.process = psutil.Process()
 
         self.cpu_begin = self.cpu_mem_used()
@@ -617,9 +619,9 @@ def __exit__(self, *exc):
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = torch.cuda.memory_allocated()
-        self.peak = torch.cuda.max_memory_allocated()
+        self.device_module.empty_cache()
+        self.end = self.device_module.memory_allocated()
+        self.peak = self.device_module.max_memory_allocated()
         self.used = b2mb(self.end - self.begin)
         self.peaked = b2mb(self.peak - self.begin)
 
@@ -733,7 +735,7 @@ def collate_fn(examples, with_prior_preservation=False):
 
 
 class PromptDataset(Dataset):
-    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+    "A simple dataset to prepare the prompts to generate class images on multiple accelerators."
 
     def __init__(self, prompt, num_samples):
         self.prompt = prompt
@@ -800,7 +802,7 @@ def main(args):
             cur_class_images = len(list(class_images_dir.iterdir()))
 
             if cur_class_images < args.num_class_images:
-                torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+                torch_dtype = torch.float16 if accelerator.device.type in ["cuda", "xpu"] else torch.float32
                 if args.prior_generation_precision == "fp32":
                     torch_dtype = torch.float32
                 elif args.prior_generation_precision == "fp16":
@@ -835,8 +837,11 @@ def main(args):
                     image.save(image_filename)
 
             del pipeline
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.xpu.is_available():
+                torch.xpu.empty_cache()
 
 
     # Handle the repository creation
    if accelerator.is_main_process:
@@ -902,7 +907,9 @@ def main(args):
         print(text_encoder)
 
     if args.enable_xformers_memory_efficient_attention:
-        if is_xformers_available():
+        if accelerator.device.type == "xpu":
+            logger.warning("XPU does not support xformers yet; ignoring --enable_xformers_memory_efficient_attention.")
+        elif is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
@@ -914,7 +921,7 @@ def main(args):
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-    if args.allow_tf32:
+    if args.allow_tf32 and torch.cuda.is_available():
         torch.backends.cuda.matmul.allow_tf32 = True
 
     if args.scale_lr:
@@ -922,7 +929,7 @@ def main(args):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB accelerators
     if args.use_8bit_adam:
         try:
             import bitsandbytes as bnb
@@ -1201,16 +1208,25 @@ def main(args):
                     pipeline.text_encoder.train()
 
                     del pipeline
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    elif torch.xpu.is_available():
+                        torch.xpu.empty_cache()
 
                 if global_step >= args.max_train_steps:
                     break
 
-        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
-        accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}")
-        accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
-        accelerator.print(f"GPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
+        # Printing the accelerator memory usage details such as allocated memory, peak memory, and total memory usage
+        accelerator.print(
+            f"{accelerator.device.type.upper()} Memory before entering the train : {b2mb(tracemalloc.begin)}"
+        )
+        accelerator.print(
+            f"{accelerator.device.type.upper()} Memory consumed at the end of the train (end-begin): {tracemalloc.used}"
+        )
+        accelerator.print(
+            f"{accelerator.device.type.upper()} Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}"
+        )
         accelerator.print(
-            f"GPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
+            f"{accelerator.device.type.upper()} Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
         )
         accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}")
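
For reviewers who want to exercise the device-resolution pattern this patch introduces in isolation, here is a minimal standalone sketch. The helper name `peak_memory_mb` and the demo allocation are illustrative, not part of the patch, and like the patch it assumes an accelerator-enabled PyTorch build. It resolves the per-backend memory module the same way the updated `TorchTracemalloc.__enter__` does, using the portable `reset_peak_memory_stats()` in place of the deprecated CUDA-only `reset_max_memory_allocated()`:

```python
import gc

import torch

# Resolve the torch submodule that owns the memory APIs for the active
# accelerator (torch.cuda, torch.xpu, ...). torch.accelerator only exists in
# newer PyTorch releases, hence the hasattr guard; older builds fall back to CUDA.
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
device_module = getattr(torch, device_type, torch.cuda)


def peak_memory_mb(fn):
    """Run `fn` and return the peak accelerator memory it allocated, in MiB."""
    gc.collect()
    device_module.empty_cache()
    device_module.reset_peak_memory_stats()  # zero the peak gauge, as TorchTracemalloc does
    begin = device_module.memory_allocated()
    fn()
    return (device_module.max_memory_allocated() - begin) // 2**20


if __name__ == "__main__":
    if device_module.is_available():
        # A 64 x 1024 x 1024 fp32 tensor is 256 MiB, so this should print ~256.
        print(peak_memory_mb(lambda: torch.empty(64, 1024, 1024, device=device_type)))
```

Resolving the module once and caching it means `__exit__` reads the same counters that `__enter__` reset, so `used` and `peaked` stay comparable whether the run landed on CUDA or XPU.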