|
|
|
@ -112,13 +112,6 @@ def is_model_parallel_parameter(p):
|
|
|
|
|
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_moe_parallel_parameter(p):
    """Report whether ``p`` should be treated as a MoE-parallel parameter.

    This is a placeholder: it answers ``False`` for every input, so no
    parameter is ever classified as MoE-parallel by this check.

    Args:
        p: The parameter to classify. It is not inspected.

    Returns:
        bool: Always ``False``.
    """
    # FIXME(HHC): clip_grad needs to be adapted for MoE.
    # This return value must stay False; otherwise it will raise an error
    # in training.
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _calc_l2_norm(grads): |
|
|
|
|
norm = 0.0 |
|
|
|
|
if len(grads) > 0: |
|
|
|
@ -214,14 +207,11 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
|
|
|
|
else: |
|
|
|
|
tensor_parallel_grads = [] |
|
|
|
|
no_tensor_parallel_grads = [] |
|
|
|
|
moe_parallel_grads = [] # used to collect moe tensor parallel gradients |
|
|
|
|
zero_sharded_grads = [] |
|
|
|
|
for p in params: |
|
|
|
|
if is_model_parallel_parameter(p): |
|
|
|
|
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type) |
|
|
|
|
tensor_parallel_grads.append(p.grad.data / reductor) |
|
|
|
|
elif is_moe_parallel_parameter(p): |
|
|
|
|
moe_parallel_grads.append(p.grad.data) |
|
|
|
|
elif hasattr(p, 'zero_is_sharded'): |
|
|
|
|
zero_sharded_grads.append(p.grad.data) |
|
|
|
|
else: |
|
|
|
@ -230,28 +220,21 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
|
|
|
|
|
if norm_type == 2.0 and enable_cuda_kernels: |
|
|
|
|
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type |
|
|
|
|
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type |
|
|
|
|
moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type |
|
|
|
|
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type |
|
|
|
|
else: |
|
|
|
|
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type) |
|
|
|
|
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type) |
|
|
|
|
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type) |
|
|
|
|
zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type) |
|
|
|
|
|
|
|
|
|
# If grads are on CPU, the norms are also on CPU. Cast them to CUDA tensors. |
|
|
|
|
if not enable_cuda_kernels: |
|
|
|
|
tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm) |
|
|
|
|
no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm) |
|
|
|
|
moe_parallel_norm = _move_norm_to_cuda(moe_parallel_norm) |
|
|
|
|
zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm) |
|
|
|
|
|
|
|
|
|
# Sum across all model-parallel GPUs. |
|
|
|
|
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0: |
|
|
|
|
dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR)) |
|
|
|
|
# Sum across all moe-tensor-parallel GPUs |
|
|
|
|
if len(moe_parallel_grads) > 0: |
|
|
|
|
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL)) |
|
|
|
|
no_tensor_parallel_norm += moe_parallel_norm |
|
|
|
|
# Sum across all zero sharded GPUs |
|
|
|
|
if len(zero_sharded_grads) > 0: |
|
|
|
|
dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA)) |
|
|
|
|