mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #2499 from feifeibear/dev0116_10
[example] check dreambooth example gradient accumulation must be 1
pull/2502/head
commit 304f1ba124
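Review summary: this commit makes three related edits to the dreambooth training script. It documents and then enforces (via an assert) that gradient_accumulation_steps must be 1 when training with Gemini, and it caches the data-parallel local rank and world size once after launch instead of querying gpc at every use site.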
@@ -153,7 +153,8 @@ def parse_args(input_args=None):
         "--gradient_accumulation_steps",
         type=int,
         default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
+        help=
+        "Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
     )
     parser.add_argument(
         "--gradient_checkpointing",
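Review note: the only change in this hunk is the help string, which now documents the Gemini restriction that a later hunk enforces with an assert. A minimal standalone sketch of the argument as it reads after the commit (the surrounding parser setup is assumed, not shown in the diff):

    import argparse

    parser = argparse.ArgumentParser(description="DreamBooth example (illustrative)")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass. If using Gemini, it must be 1",
    )
    args = parser.parse_args([])             # parse defaults for the demo
    print(args.gradient_accumulation_steps)  # 1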
@@ -361,6 +362,9 @@ def main(args):
     else:
         colossalai.launch_from_torch(config={}, seed=args.seed)

+    local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    world_size = gpc.get_world_size(ParallelMode.DATA)
+
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
         if not class_images_dir.exists():
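Review note: caching these two values is safe because both are fixed once colossalai.launch_from_torch returns; the hunks below then read the locals instead of repeating the gpc lookups. A minimal sketch of the pattern, assuming the legacy gpc/ParallelMode API this script already uses (run under torchrun so the launcher finds its environment variables):

    import colossalai
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    colossalai.launch_from_torch(config={}, seed=0)

    # Fixed after launch, so query once and reuse everywhere below.
    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)
    print(f"data-parallel rank {local_rank} of {world_size}")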
@@ -388,7 +392,7 @@ def main(args):
         for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
-                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+                disable=not local_rank == 0,
         ):
             images = pipeline(example["prompt"]).images

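Review note: same substitution as elsewhere, here gating the class-image generation progress bar so only rank 0 renders it while the other data-parallel ranks pass disable=True. A toy sketch of the gating (the real loop runs the diffusion pipeline):

    from tqdm import tqdm

    local_rank = 0  # assumed cached earlier, as in the -361 hunk

    for example in tqdm(range(4), desc="Generating class images",
                        disable=not local_rank == 0):
        pass  # pipeline(example["prompt"]) in the real script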
@@ -400,7 +404,7 @@ def main(args):
         del pipeline

     # Handle the repository creation
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         if args.push_to_hub:
             if args.hub_model_id is None:
                 repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
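Review note: the rank check guards all hub-related side effects, so only rank 0 resolves a repository name and pushes. A hypothetical condensation of the gated block (resolve_repo_name is illustrative, not a helper from the script):

    from pathlib import Path

    def resolve_repo_name(args, local_rank):
        # Rank-0-only side effects: every other data-parallel rank skips
        # the hub round trip entirely.
        if local_rank != 0 or not args.push_to_hub:
            return None
        if args.hub_model_id is None:
            # The script derives this via get_full_repo_name(...); the
            # output-directory name stands in for it in this sketch.
            return Path(args.output_dir).name
        return args.hub_model_id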
@@ -465,8 +469,9 @@ def main(args):
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    assert args.gradient_accumulation_steps == 1, "if using ColossalAI gradient_accumulation_steps must be set to 1."
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * world_size

     unet = gemini_zero_dpp(unet, args.placement)

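Review note: the new assert pins the accumulation factor to 1 under Gemini, so with --scale_lr the linear scaling rule collapses to learning_rate * train_batch_size * world_size. Worked through with illustrative numbers (not taken from the commit):

    learning_rate = 5e-6
    gradient_accumulation_steps = 1   # enforced by the assert above
    train_batch_size = 2
    world_size = 4                    # cached data-parallel world size

    scaled = learning_rate * gradient_accumulation_steps * train_batch_size * world_size
    print(scaled)  # 4e-05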
@@ -555,7 +560,7 @@ def main(args):
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

     # Train!
-    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size * args.gradient_accumulation_steps

     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
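Review note: with the same illustrative numbers, the reported total batch size is the per-GPU batch times the data-parallel degree, the accumulation factor contributing nothing under Gemini:

    train_batch_size = 2              # per-GPU batch size
    world_size = 4                    # data-parallel degree
    gradient_accumulation_steps = 1   # must be 1 under Gemini

    total_batch_size = train_batch_size * world_size * gradient_accumulation_steps
    print(total_batch_size)  # 8 samples per optimizer step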
@@ -567,7 +572,7 @@ def main(args):
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])

     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
     progress_bar.set_description("Steps")
     global_step = 0

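Readability aside: disable=not local_rank == 0 parses as not (local_rank == 0) because unary not binds more loosely than ==; local_rank != 0 would be equivalent and arguably clearer, though the commit keeps the original spelling:

    local_rank = 1
    assert (not local_rank == 0) == (local_rank != 0)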
@@ -644,7 +649,7 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
                 torch_unet = get_static_torch_model(unet)
-                if gpc.get_local_rank(ParallelMode.DATA) == 0:
+                if local_rank == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
                         unet=torch_unet,
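Review note: the ordering in this hunk matters. torch.cuda.synchronize() drains in-flight kernels, get_static_torch_model(unet) reconstructs a plain torch module from the Gemini-managed one on every rank (which appears to be why it runs before the rank check), and only rank 0 then builds the pipeline. A hedged sketch of the same save pattern, with an assumed import path and a hypothetical output file:

    import torch
    from colossalai.nn.parallel.utils import get_static_torch_model  # assumed module path

    def save_unet(unet, local_rank, path="unet-checkpoint.pt"):
        torch.cuda.synchronize()                   # drain pending CUDA work
        torch_unet = get_static_torch_model(unet)  # all ranks participate
        if local_rank == 0:                        # only rank 0 writes to disk
            torch.save(torch_unet.state_dict(), path)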
@@ -659,7 +664,7 @@ def main(args):
     torch.cuda.synchronize()
     unet = get_static_torch_model(unet)

-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,