From 039854b39165ab7f2a4fa7ab3d67e47daa325d1c Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Thu, 8 Jun 2023 13:17:58 +0800 Subject: [PATCH] modify shell for check --- examples/images/dreambooth/test_ci.sh | 6 +++--- examples/images/dreambooth/train_dreambooth_colossalai.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh index 8d18e1d4a..35c81b325 100644 --- a/examples/images/dreambooth/test_ci.sh +++ b/examples/images/dreambooth/test_ci.sh @@ -6,8 +6,8 @@ HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 -# "torch_ddp" "torch_ddp_fp16" -for plugin in "low_level_zero" "gemini"; do +# "torch_ddp" "torch_ddp_fp16" "low_level_zero" +for plugin in "gemini"; do torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \ --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \ --instance_data_dir="/data/dreambooth/Teyvat/data" \ @@ -20,5 +20,5 @@ for plugin in "low_level_zero" "gemini"; do --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --placement="cuda" + --placement="cpu" # "cuda" done diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index eae52b5ec..44bde9226 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -487,7 +487,7 @@ def main(args): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5) + plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)