diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index 7c90b939a..9c72c06e7 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -153,7 +153,8 @@ def parse_args(input_args=None):
         "--gradient_accumulation_steps",
         type=int,
         default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
+        help=
+        "Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
     )
     parser.add_argument(
         "--gradient_checkpointing",
@@ -361,6 +362,9 @@ def main(args):
     else:
         colossalai.launch_from_torch(config={}, seed=args.seed)
 
+    local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    world_size = gpc.get_world_size(ParallelMode.DATA)
+
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
@@ -388,7 +392,7 @@ def main(args):
             for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
-                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+                disable=not local_rank == 0,
             ):
                 images = pipeline(example["prompt"]).images
 
@@ -400,7 +404,7 @@ def main(args):
         del pipeline
 
     # Handle the repository creation
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         if args.push_to_hub:
             if args.hub_model_id is None:
                 repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@@ -465,8 +469,9 @@ def main(args):
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
 
+    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size
 
     unet = gemini_zero_dpp(unet, args.placement)
 
@@ -555,7 +560,7 @@ def main(args):
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Train!
-    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
@@ -567,7 +572,7 @@ def main(args):
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])
 
     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
     progress_bar.set_description("Steps")
     global_step = 0
 
@@ -644,7 +649,7 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
                 torch_unet = get_static_torch_model(unet)
-                if gpc.get_local_rank(ParallelMode.DATA) == 0:
+                if local_rank == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
                         unet=torch_unet,
@@ -659,7 +664,7 @@ def main(args):
     torch.cuda.synchronize()
     unet = get_static_torch_model(unet)
 
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,
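For reference, below is a minimal sketch of the pattern the patch introduces: query the data-parallel rank and world size once right after launch, reuse the cached values for rank-0 gating and LR scaling, and enforce `gradient_accumulation_steps == 1` when training with Gemini. It assumes the legacy `colossalai.core.global_context` API that the script already imports; the helper name `setup_parallel_context` is illustrative only, not part of the diff.

```python
# Sketch only, not the script itself. Assumes the legacy ColossalAI context API
# (colossalai.core.global_context / ParallelMode) used by train_dreambooth_colossalai.py.
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


def setup_parallel_context(args):
    # Launch the distributed context from torchrun-style environment variables.
    colossalai.launch_from_torch(config={}, seed=args.seed)

    # Cache once; every later rank-0 check and the LR-scaling step reuse these values
    # instead of calling gpc.get_local_rank / gpc.get_world_size repeatedly.
    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)

    # Gemini (ZeRO) does not support gradient accumulation here, so the script
    # requires gradient_accumulation_steps == 1.
    assert args.gradient_accumulation_steps == 1, \
        "gradient_accumulation_steps must be 1 when training with Gemini."

    if args.scale_lr:
        # Linear LR scaling with the effective data-parallel batch size.
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    return local_rank, world_size
```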