From 9b5e7ce21feb51977d11da4e6a0ed35f502dbfb5 Mon Sep 17 00:00:00 2001
From: Maruyama_Aya <china6280111@126.com>
Date: Thu, 8 Jun 2023 14:56:56 +0800
Subject: [PATCH] modify shell for check

---
 examples/images/dreambooth/colossalai.sh                  | 1 +
 examples/images/dreambooth/test_ci.sh                     | 1 +
 examples/images/dreambooth/train_dreambooth_colossalai.py | 5 +++++
 3 files changed, 7 insertions(+)

diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh
index b2a544928..db4562dbc 100755
--- a/examples/images/dreambooth/colossalai.sh
+++ b/examples/images/dreambooth/colossalai.sh
@@ -14,4 +14,5 @@ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
+  --test_run=True \
   --placement="auto" \
diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh
index 0e3f6efa4..21f45adae 100644
--- a/examples/images/dreambooth/test_ci.sh
+++ b/examples/images/dreambooth/test_ci.sh
@@ -19,6 +19,7 @@ for plugin in "gemini"; do
   --learning_rate=5e-6 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
+  --test_run=True \
   --num_class_images=200 \
   --placement="auto" # "cuda"
 done
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index 44bde9226..888b28de8 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -198,6 +198,7 @@ def parse_args(input_args=None):
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
     parser.add_argument(
         "--hub_model_id",
         type=str,
@@ -267,6 +268,7 @@ class DreamBoothDataset(Dataset):
         class_prompt=None,
         size=512,
         center_crop=False,
+        test=False,
     ):
         self.size = size
         self.center_crop = center_crop
@@ -277,6 +279,8 @@ class DreamBoothDataset(Dataset):
             raise ValueError("Instance images root doesn't exists.")
 
         self.instance_images_path = list(Path(instance_data_root).iterdir())
+        if test:
+            self.instance_images_path = self.instance_images_path[:10]
         self.num_instance_images = len(self.instance_images_path)
         self.instance_prompt = instance_prompt
         self._length = self.num_instance_images
@@ -509,6 +513,7 @@ def main(args):
         tokenizer=tokenizer,
         size=args.resolution,
         center_crop=args.center_crop,
+        test=args.test_run
     )
 
     def collate_fn(examples):