From 9b5e7ce21feb51977d11da4e6a0ed35f502dbfb5 Mon Sep 17 00:00:00 2001 From: Maruyama_Aya <china6280111@126.com> Date: Thu, 8 Jun 2023 14:56:56 +0800 Subject: [PATCH] modify shell for check --- examples/images/dreambooth/colossalai.sh | 1 + examples/images/dreambooth/test_ci.sh | 1 + examples/images/dreambooth/train_dreambooth_colossalai.py | 5 +++++ 3 files changed, 7 insertions(+) diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh index b2a544928..db4562dbc 100755 --- a/examples/images/dreambooth/colossalai.sh +++ b/examples/images/dreambooth/colossalai.sh @@ -14,4 +14,5 @@ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ + --test_run=True \ --placement="auto" \ diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh index 0e3f6efa4..21f45adae 100644 --- a/examples/images/dreambooth/test_ci.sh +++ b/examples/images/dreambooth/test_ci.sh @@ -19,6 +19,7 @@ for plugin in "gemini"; do --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ + --test_run=True \ --num_class_images=200 \ --placement="auto" # "cuda" done diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 44bde9226..888b28de8 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -198,6 +198,7 @@ def parse_args(input_args=None): parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.") parser.add_argument( "--hub_model_id", type=str, @@ -267,6 +268,7 @@ class DreamBoothDataset(Dataset): class_prompt=None, size=512, center_crop=False, + test=False, ): self.size = size self.center_crop = center_crop @@ -277,6 +279,8 @@ class DreamBoothDataset(Dataset): raise ValueError("Instance images root doesn't exists.") self.instance_images_path = list(Path(instance_data_root).iterdir()) + if test: + self.instance_images_path = self.instance_images_path[:10] self.num_instance_images = len(self.instance_images_path) self.instance_prompt = instance_prompt self._length = self.num_instance_images @@ -509,6 +513,7 @@ def main(args): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, + test=args.test_run ) def collate_fn(examples):