mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix grad accumulation plus clipping for gemini (#5002)
parent dc003c304c
commit d99b2c961a
@@ -637,6 +637,7 @@ class Chunk:
             # grad chunk is initialized, just reallocate cuda global chunk
             self.grad_chunk.cuda_shard = None
             self.grad_chunk.is_gathered = True
+            self.grad_chunk.l2_norm = None
             alloc_storage(self.grad_chunk.cuda_global_chunk)
 
         return self.grad_chunk
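For context: the added line clears the cached gradient L2 norm whenever an already-initialized grad chunk has its CUDA storage reallocated, so a norm computed from the previous iteration's gradients cannot leak into the next clipping step. A minimal standalone sketch of that failure mode, using a hypothetical GradBuffer class rather than ColossalAI's actual Chunk, assuming the norm is cached at reduction time and consumed by clipping:

import torch

class GradBuffer:
    """Hypothetical stand-in for a grad chunk that caches its L2 norm for clipping."""

    def __init__(self, numel: int):
        self.data = torch.zeros(numel)
        self.l2_norm = None  # cached at reduction time, read later by clipping

    def reduce(self):
        # Cache the norm so clipping does not have to re-scan the buffer.
        self.l2_norm = self.data.norm(2).item()

    def reuse_for_next_round(self):
        # Without this reset, the next clipping step would read a stale norm
        # computed from the previous accumulation round's gradients.
        self.l2_norm = None

The commit applies the same reset in both places where an initialized grad chunk is reused: here, when its storage is reallocated, and in the GeminiDDP hunk below, when accumulation keeps writing into it.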
@@ -343,6 +343,7 @@ class GeminiDDP(ModelWrapper):
                 grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
             else:
                 grad_chunk = chunk.grad_chunk
+                chunk.grad_chunk.l2_norm = None
 
             # hold -> compute -> hold after bwd
             grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
@@ -49,7 +49,10 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("keep_gathered", [False, True])
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
-def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
+@parameterize("use_grad_checkpoint", [False, True])
+def exam_gemini_grad_acc(
+    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+):
     init_device = get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
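The test gains a fourth parameterization axis, use_grad_checkpoint. As a rough illustration of how stacked parameterize-style decorators expand into a cross-product of test cases (a simplified sketch, not colossalai.testing.parameterize itself):

def parameterize_sketch(name, values):
    """Simplified parameterize-style decorator: reruns the wrapped test once per value."""

    def decorator(fn):
        def wrapper(*args, **kwargs):
            for value in values:
                fn(*args, **{**kwargs, name: value})

        return wrapper

    return decorator

# Stacking four such decorators reruns exam_gemini_grad_acc over the full cross-product:
# 2 (keep_gathered) x 1 (model_name) x 2 (master_weights) x 2 (use_grad_checkpoint)
# = 8 combinations per placement_config.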
@@ -63,6 +66,10 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
     for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
         torch_p.data.copy_(p.data)
 
+    if use_grad_checkpoint:
+        gemini_model.gradient_checkpointing_enable()
+        torch_model.gradient_checkpointing_enable()
+
     world_size = torch.distributed.get_world_size()
     config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
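The new flag enables activation checkpointing on both the Gemini model and the torch baseline, so accumulation plus clipping is also exercised under recomputation. For Hugging Face transformer models such as transformers_gpt_lm, gradient_checkpointing_enable() is the usual switch; a minimal usage sketch, independent of this test harness and using a small hand-built GPT2 config rather than the model_zoo builder:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny config so the sketch runs quickly; the real test builds the model via model_zoo.
model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
model.gradient_checkpointing_enable()  # trade extra compute for lower activation memory
model.train()  # checkpointing only matters during training

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()  # activations are recomputed during this backward pass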
@@ -77,7 +84,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
-    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
+    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
 
     rank = dist.get_rank()
 
@@ -112,6 +119,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
             check_grad(gemini_model, torch_model)
 
         if (i + 1) % accum_iter == 0:
+            torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
             torch_optim.step()
             gemini_optim.step()
             torch_optim.zero_grad()
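The baseline side now clips the master gradients once per accumulation window, right before the optimizer step, which is the behaviour the Gemini optimizer (constructed with max_norm=1.0 above) is expected to match. A generic plain-PyTorch version of that pattern, with no AMP or master-parameter handling and hypothetical model/batches arguments:

import torch

def train_with_accumulation(model, optimizer, batches, accum_iter=4, max_norm=1.0):
    """Gradient accumulation with clipping applied only at the step boundary."""
    for i, (inputs, targets) in enumerate(batches):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        (loss / accum_iter).backward()  # gradients accumulate across micro-batches

        if (i + 1) % accum_iter == 0:
            # Clip the *accumulated* gradients once, then step and reset.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            optimizer.zero_grad()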
@@ -88,7 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     )
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
-    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
+    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)
 
     model.train()
     torch_model.train()
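Switching the keyword from clipping_norm to max_norm keeps the grad-clipping test in line with the argument the accumulation test now passes to GeminiOptimizer. If one wanted to assert the clipping effect directly on the baseline side, a hypothetical helper like the one below (not part of the test suite) could check the global gradient norm against the configured bound:

import torch

def global_grad_norm(parameters, norm_type: float = 2.0) -> float:
    """Global gradient norm, computed the same way clip_grad_norm_ reports it."""
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    return torch.norm(torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type).item()

# After torch.nn.utils.clip_grad_norm_(params, 1.0), global_grad_norm(params)
# should be <= 1.0 up to floating-point tolerance.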