From cfd1d5ee4970ecf168a09ca6d5f187b4520eaea3 Mon Sep 17 00:00:00 2001 From: Haofan Wang Date: Wed, 11 Jan 2023 16:56:15 +0800 Subject: [PATCH] [example] fixed seed error in train_dreambooth_colossalai.py (#2445) --- examples/images/dreambooth/train_dreambooth_colossalai.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index b7e24bfe4..7c90b939a 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -355,10 +355,11 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"): def main(args): - colossalai.launch_from_torch(config={}) - if args.seed is not None: - gpc.set_seed(args.seed) + if args.seed is None: + colossalai.launch_from_torch(config={}) + else: + colossalai.launch_from_torch(config={}, seed=args.seed) if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir)