From d99b2c961acdd39aa0aef976eef92f76dc9af44c Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Thu, 2 Nov 2023 17:59:10 +0800
Subject: [PATCH] [hotfix] fix grad accumulation plus clipping for gemini (#5002)

---
 colossalai/zero/gemini/chunk/chunk.py          |  1 +
 colossalai/zero/gemini/gemini_ddp.py           |  1 +
 tests/test_zero/test_gemini/test_grad_accum.py | 12 ++++++++++--
 tests/test_zero/test_gemini/test_grad_clip.py  |  2 +-
 4 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index d3309fc53..4ea6cc662 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -637,6 +637,7 @@ class Chunk:
             # grad chunk is initialized, just reallocate cuda global chunk
             self.grad_chunk.cuda_shard = None
             self.grad_chunk.is_gathered = True
+            self.grad_chunk.l2_norm = None
             alloc_storage(self.grad_chunk.cuda_global_chunk)
 
         return self.grad_chunk
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index df7e1163c..565f50c90 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -343,6 +343,7 @@ class GeminiDDP(ModelWrapper):
                     grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
                 else:
                     grad_chunk = chunk.grad_chunk
+                    chunk.grad_chunk.l2_norm = None
 
             # hold -> compute -> hold after bwd
             grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index 5e36b1838..bfd3ebfcb 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -49,7 +49,10 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("keep_gathered", [False, True])
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
-def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
+@parameterize("use_grad_checkpoint", [False, True])
+def exam_gemini_grad_acc(
+    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+):
     init_device = get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -63,6 +66,10 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
     for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
         torch_p.data.copy_(p.data)
 
+    if use_grad_checkpoint:
+        gemini_model.gradient_checkpointing_enable()
+        torch_model.gradient_checkpointing_enable()
+
     world_size = torch.distributed.get_world_size()
     config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
@@ -77,7 +84,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
-    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
+    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
 
     rank = dist.get_rank()
 
@@ -112,6 +119,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
             check_grad(gemini_model, torch_model)
 
         if (i + 1) % accum_iter == 0:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
             torch_optim.step()
             gemini_optim.step()
             torch_optim.zero_grad()
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index c3a36d3ba..23b3504fd 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -88,7 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     )
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
+    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)
 
     model.train()
     torch_model.train()
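
Reviewer note (not part of the patch): the sketch below shows how the fixed path is exercised, i.e. gradient accumulation with clipping enabled through GeminiOptimizer's max_norm argument, mirroring the updated test. The import paths, the placeholder build_model()/train_dataloader, and relying on GeminiDDP's default chunk configuration are assumptions for illustration, not taken from this diff.

    # Minimal sketch of grad accumulation + clipping with Gemini (assumptions as noted above).
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import GeminiDDP, GeminiOptimizer

    # build_model() and train_dataloader are placeholders; distributed init
    # (colossalai.launch / torchrun) is assumed to have happened already.
    model = GeminiDDP(build_model(), pin_memory=True)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    # max_norm enables gradient clipping inside GeminiOptimizer.step()
    gemini_optim = GeminiOptimizer(optimizer, model, initial_scale=1, max_norm=1.0)

    accum_iter = 4
    for i, batch in enumerate(train_dataloader):
        loss = model(**batch)                     # placeholder forward returning a scalar loss
        gemini_optim.backward(loss / accum_iter)  # accumulate scaled gradients over micro-batches
        if (i + 1) % accum_iter == 0:
            gemini_optim.step()                   # clipping applied here via max_norm
            gemini_optim.zero_grad()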