mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix grad accumulation plus clipping for gemini (#5002)
parent
dc003c304c
commit
d99b2c961a
|
@ -637,6 +637,7 @@ class Chunk:
|
|||
# grad chunk is initialized, just reallocate cuda global chunk
|
||||
self.grad_chunk.cuda_shard = None
|
||||
self.grad_chunk.is_gathered = True
|
||||
self.grad_chunk.l2_norm = None
|
||||
alloc_storage(self.grad_chunk.cuda_global_chunk)
|
||||
|
||||
return self.grad_chunk
|
||||
|
|
|
@ -343,6 +343,7 @@ class GeminiDDP(ModelWrapper):
|
|||
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
|
||||
else:
|
||||
grad_chunk = chunk.grad_chunk
|
||||
chunk.grad_chunk.l2_norm = None
|
||||
|
||||
# hold -> compute -> hold after bwd
|
||||
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
|
||||
|
|
|
@ -49,7 +49,10 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
|
|||
@parameterize("keep_gathered", [False, True])
|
||||
@parameterize("model_name", ["transformers_gpt_lm"])
|
||||
@parameterize("master_weights", [False, True])
|
||||
def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
|
||||
@parameterize("use_grad_checkpoint", [False, True])
|
||||
def exam_gemini_grad_acc(
|
||||
placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
|
||||
):
|
||||
init_device = get_current_device()
|
||||
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
||||
iter(model_zoo.get_sub_registry(model_name).values())
|
||||
|
@ -63,6 +66,10 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
|
|||
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
|
||||
torch_p.data.copy_(p.data)
|
||||
|
||||
if use_grad_checkpoint:
|
||||
gemini_model.gradient_checkpointing_enable()
|
||||
torch_model.gradient_checkpointing_enable()
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
|
||||
config_dict[world_size]["chunk_size"] = 5000
|
||||
|
@ -77,7 +84,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
|
|||
**placement_config,
|
||||
)
|
||||
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
|
||||
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
|
||||
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
|
||||
|
||||
rank = dist.get_rank()
|
||||
|
||||
|
@ -112,6 +119,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
|
|||
check_grad(gemini_model, torch_model)
|
||||
|
||||
if (i + 1) % accum_iter == 0:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
|
||||
torch_optim.step()
|
||||
gemini_optim.step()
|
||||
torch_optim.zero_grad()
|
||||
|
|
|
@ -88,7 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
|
|||
)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
|
||||
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)
|
||||
|
||||
model.train()
|
||||
torch_model.train()
|
||||
|
|
Loading…
Reference in New Issue