[hotfix] fix grad accumulation plus clipping for gemini (#5002)

Baizhou Zhang 2023-11-02 17:59:10 +08:00 committed by GitHub
parent dc003c304c
commit d99b2c961a
4 changed files with 13 additions and 3 deletions

@@ -637,6 +637,7 @@ class Chunk:
             # grad chunk is initialized, just reallocate cuda global chunk
             self.grad_chunk.cuda_shard = None
             self.grad_chunk.is_gathered = True
+            self.grad_chunk.l2_norm = None
             alloc_storage(self.grad_chunk.cuda_global_chunk)
 
         return self.grad_chunk
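Why the reset matters: the grad chunk carries a cached l2_norm that the clipping path consumes. When gradient accumulation reuses an already initialized grad chunk, a norm left over from the previous step could be picked up again and skew the clipping coefficient, which is what this hotfix guards against. The snippet below is a minimal sketch of that failure mode; GradChunk and compute_l2_norm are illustrative names, not the library's API.

import torch


class GradChunk:
    """Hypothetical stand-in for a grad chunk that caches its squared L2 norm."""

    def __init__(self, numel: int):
        self.flat_grad = torch.zeros(numel)
        self.l2_norm = None  # lazily cached squared L2 norm

    def compute_l2_norm(self) -> float:
        # Only compute once; reuse the cached value afterwards.
        if self.l2_norm is None:
            self.l2_norm = float(self.flat_grad.norm(2) ** 2)
        return self.l2_norm


chunk = GradChunk(numel=4)
chunk.flat_grad.fill_(3.0)
print(chunk.compute_l2_norm())  # 36.0 for the first accumulation cycle

# The next accumulation cycle reuses the same chunk with fresh gradients.
chunk.flat_grad.fill_(1.0)
print(chunk.compute_l2_norm())  # still 36.0: stale norm, wrong clip coefficient

# The hotfix clears the cache whenever the grad chunk is reused, forcing a recompute.
chunk.l2_norm = None
print(chunk.compute_l2_norm())  # 4.0, computed from the new gradients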

@@ -343,6 +343,7 @@ class GeminiDDP(ModelWrapper):
             grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
         else:
             grad_chunk = chunk.grad_chunk
+            chunk.grad_chunk.l2_norm = None
 
         # hold -> compute -> hold after bwd
         grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
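The same reset is applied on the GeminiDDP side when an existing grad chunk is picked up during backward. For context, norm-based clipping combines the per-chunk squared norms into one global norm and rescales gradients when that norm exceeds max_norm. The helper below is a generic sketch of that rule under this assumption, not Gemini's actual implementation.

import math
from typing import Iterable, List

import torch


def clip_by_global_norm(chunk_sq_norms: Iterable[float],
                        grads: List[torch.Tensor],
                        max_norm: float,
                        eps: float = 1e-6) -> float:
    """Scale `grads` in place so their combined L2 norm is at most `max_norm`.

    `chunk_sq_norms` holds the cached squared norm of each grad chunk; a stale
    entry here is exactly what the l2_norm reset above guards against.
    """
    total_norm = math.sqrt(sum(chunk_sq_norms))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm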

@@ -49,7 +49,10 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("keep_gathered", [False, True])
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
-def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
+@parameterize("use_grad_checkpoint", [False, True])
+def exam_gemini_grad_acc(
+    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+):
     init_device = get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -63,6 +66,10 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
     for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
         torch_p.data.copy_(p.data)
 
+    if use_grad_checkpoint:
+        gemini_model.gradient_checkpointing_enable()
+        torch_model.gradient_checkpointing_enable()
+
     world_size = torch.distributed.get_world_size()
     config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
@@ -77,7 +84,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
-    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
+    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
 
     rank = dist.get_rank()
@@ -112,6 +119,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
         check_grad(gemini_model, torch_model)
 
         if (i + 1) % accum_iter == 0:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
             torch_optim.step()
             gemini_optim.step()
             torch_optim.zero_grad()
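Taken together, the updated test now exercises gradient accumulation, optional activation checkpointing, and clipping at the same time, stepping the Gemini model and the torch/AMP reference in lockstep. The helper below is a condensed, hypothetical sketch of that pattern; names such as train_dataloader and criterion are placeholders, and the real test clips the AMP master params as shown in the diff above.

import torch


def train_in_lockstep(gemini_model, gemini_optim, torch_model, torch_optim,
                      train_dataloader, criterion, accum_iter: int = 4,
                      max_norm: float = 1.0) -> None:
    """Accumulate over `accum_iter` micro-batches, then clip and step both models."""
    for i, batch in enumerate(train_dataloader):
        # Gemini side: GeminiOptimizer drives backward; clipping was enabled at
        # construction time via max_norm.
        gemini_loss = criterion(gemini_model(**batch)) / accum_iter
        gemini_optim.backward(gemini_loss)

        # Reference side: plain PyTorch with explicit clipping before the step.
        torch_loss = criterion(torch_model(**batch)) / accum_iter
        torch_loss.backward()

        if (i + 1) % accum_iter == 0:
            torch.nn.utils.clip_grad_norm_(torch_model.parameters(), max_norm)
            torch_optim.step()
            torch_optim.zero_grad()
            gemini_optim.step()
            gemini_optim.zero_grad()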

@@ -88,7 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
+    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)
 
     model.train()
     torch_model.train()
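The companion clipping test gets the same treatment: the clipping threshold GeminiOptimizer reads is max_norm, and the clipping_norm keyword it previously received appears not to have configured clipping, so the intended 1.0 threshold was not being applied. A minimal usage sketch, assuming model is already built and wrapped as in the test:

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiOptimizer

# Enable norm-based gradient clipping through the wrapper itself.
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)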