diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index bffd26f59..5eaa4af4d 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -6,7 +6,7 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} export GPUNUM=${GPUNUM:-1} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} -export TRAIN_STEP=${TRAIN_STEP:-2} +export TRAIN_STEP=${TRAIN_STEP:-10} # export PYTHONPATH=$PWD:$PYTHONPATH diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index bf1be87ba..4911ff124 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -66,11 +66,11 @@ class GPTLMLoss(nn.Module): def get_cpu_mem(): - return psutil.Process().memory_info().rss / 1024**2 # MB unit + return psutil.Process().memory_info().rss / 1024**2 def get_gpu_mem(): - return torch.cuda.memory_allocated() / 1024**2 # MB unit + return torch.cuda.memory_allocated() / 1024**2 def get_mem_info(prefix=""): @@ -78,7 +78,6 @@ def get_mem_info(prefix=""): def get_model_size(model: nn.Module): - # get the number of parameter of the model total_numel = 0 for module in model.modules(): for p in module.parameters(recurse=False): @@ -130,7 +129,7 @@ def main(): WARMUP_STEPS = 1 assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps" assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median" - PROF_FLAG = True # The flag of profiling, False by default + PROF_FLAG = False # The flag of profiling, False by default disable_existing_loggers() colossalai.launch_from_torch() @@ -167,7 +166,7 @@ def main(): stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True ) elif args.distplan == "CAI_Gemini": - plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1) + plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) else: raise RuntimeError @@ -249,7 +248,7 @@ def main(): prof.step() tflops_list.sort() - median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1) + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") torch.cuda.synchronize() diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 517031c83..4d3981329 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -40,7 +40,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("use_grad_checkpoint", [False, True]) @parameterize("master_weights", [False, True]) -@parameterize("max_prefetch", [0, 1, 4]) +@parameterize("max_prefetch", [0, 4]) @parameterize("enable_async_reduce", [False, True]) def exam_gpt_fwd_bwd( placement_config, diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index c2c11a8f3..002741389 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -50,7 +50,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [False, True]) @parameterize("use_grad_checkpoint", [False, True]) -@parameterize("max_prefetch", [0, 1, 4]) +@parameterize("max_prefetch", [0, 4]) @parameterize("enable_async_reduce", [False, True]) def exam_gemini_grad_acc( placement_config, diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index a0cbc7d60..c610259b2 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -40,7 +40,9 @@ EXAMPLE_MODELS = [ ] # bfloat16 cannot represent them exactly -BF16_IGNORED_KEYS = ["masked_bias"] +BF16_IGNORED_KEYS = [ + "masked_bias", +] def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): @@ -71,15 +73,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("master_weights", [True, False]) -@parameterize("max_prefetch", [0, 1, 4]) @parameterize("enable_async_reduce", [False, True]) def exam_model_step( - placement_config, - model_name: str, - mixed_precision: torch.dtype, - master_weights: bool, - max_prefetch: int, - enable_async_reduce=True, + placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True ): set_seed(42) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( @@ -108,7 +104,6 @@ def exam_model_step( **placement_config, mixed_precision=mixed_precision, master_weights=master_weights, - max_prefetch=max_prefetch, enable_async_reduce=enable_async_reduce, ) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 3cbd36917..23e2d8083 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -28,8 +28,7 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"]) @parameterize("master_weights", [False, True]) -@parameterize("max_prefetch", [0, 1, 4]) -def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool, max_prefetch: int): +def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -45,14 +44,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP( - model, - config_dict, - **placement_config, - pin_memory=True, - master_weights=master_weights, - max_prefetch=max_prefetch, - ) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) model.train() zero_dict = model.state_dict(only_rank_0=False) diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index a721c96a1..8d70ae3b1 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -20,8 +20,7 @@ PLACEMENT_CONFIGS = [ @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) -@parameterize("max_prefetch", [0, 1, 4]) -def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch): +def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) model_builder, data_gen_fn, output_transform_fn, *_ = next( iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) @@ -36,7 +35,7 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered, max_prefetch): config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, max_prefetch=max_prefetch) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) optimizer = HybridAdam(model.parameters()) optim = GeminiOptimizer(optimizer, model, initial_scale=32) # initialize the link between chunk16 and chunk32