[MOE] remove old MoE legacy (#493)

pull/496/head
Authored by HELSON 3 years ago, committed by GitHub
parent c4c02424f3
commit f24b5ed201

@@ -45,7 +45,3 @@ class ParallelMode(Enum):
     PARALLEL_2P5D_COL = '2p5d_col'
     PARALLEL_2P5D_DEP = '2p5d_dep'
     PARALLEL_2P5D_XZ = '2p5d_xz'
-
-    # MOE parallel
-    MOE_DATA = 'moe_data'
-    MOE_MODEL = 'moe_model'
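
For reference, a minimal sketch of what the tail of the enum looks like after this removal (only the members shown in the hunk above are reproduced; the rest of the enum is elided). Any code that still references the deleted members now fails at attribute lookup:

```python
from enum import Enum

class ParallelMode(Enum):
    # ...earlier members elided...
    PARALLEL_2P5D_COL = '2p5d_col'
    PARALLEL_2P5D_DEP = '2p5d_dep'
    PARALLEL_2P5D_XZ = '2p5d_xz'
    # MOE_DATA and MOE_MODEL are gone after this commit

# ParallelMode.MOE_MODEL  ->  raises AttributeError
```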

@@ -2,10 +2,10 @@ from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
-                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_moe_parallel_parameter,
-                     is_no_pp_or_last_stage, is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence,
-                     multi_tensor_applier, param_is_not_tensor_parallel_duplicate, print_rank_0,
-                     switch_virtual_pipeline_parallel_rank, sync_model_param)
+                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
+                     is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
+                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
+                     sync_model_param)
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
@@ -18,5 +18,5 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter', 'TensorDetector'
+    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector'
 ]

@@ -112,13 +112,6 @@ def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
 
-def is_moe_parallel_parameter(p):
-    # FIXME(HHC): clip_grad need to changed to adapted for MoE
-    # This return value must set to False, otherwise it will raise
-    # an error in training
-    return False
-
-
 def _calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:
@@ -214,14 +207,11 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     else:
         tensor_parallel_grads = []
         no_tensor_parallel_grads = []
-        moe_parallel_grads = []  # used to collect moe tensor parallel gradients
         zero_sharded_grads = []
         for p in params:
            if is_model_parallel_parameter(p):
                reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
                tensor_parallel_grads.append(p.grad.data / reductor)
-            elif is_moe_parallel_parameter(p):
-                moe_parallel_grads.append(p.grad.data)
            elif hasattr(p, 'zero_is_sharded'):
                zero_sharded_grads.append(p.grad.data)
            else:
@@ -230,28 +220,21 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         if norm_type == 2.0 and enable_cuda_kernels:
             tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
             no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
-            moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
             zero_sharded_norm = _calc_l2_norm(zero_sharded_grads)**norm_type
         else:
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
             no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
-            moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
             zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
 
         # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
         if not enable_cuda_kernels:
             tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
             no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
-            moe_parallel_norm = _move_norm_to_cuda(moe_parallel_norm)
             zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
 
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
             dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
-        # Sum across all moe-tensor-parallel GPUs
-        if len(moe_parallel_grads) > 0:
-            dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
-            no_tensor_parallel_norm += moe_parallel_norm
         # Sum across all zero sharded GPUs
         if len(zero_sharded_grads) > 0:
             dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
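
To make the simplified control flow easier to follow, below is a minimal, self-contained sketch of how clip_grad_norm_fp32 buckets gradients after this commit. The attribute name `is_tensor_parallel` and the standalone `lp_norm` helper are stand-ins for illustration only, not the library's actual API; the real function additionally scales tensor-parallel gradients and all-reduces the partial norms across process groups, as noted in the comments.

```python
import torch

def grouped_grad_norms(params, norm_type=2.0):
    """Sketch of the post-removal bucketing: tensor-parallel, zero-sharded,
    and 'everything else'. The MoE bucket and its extra all-reduce over
    ParallelMode.MOE_MODEL no longer exist."""
    tensor_parallel_grads, zero_sharded_grads, no_tensor_parallel_grads = [], [], []
    for p in params:
        if getattr(p, 'is_tensor_parallel', False):    # stand-in for is_model_parallel_parameter(p)
            tensor_parallel_grads.append(p.grad.data)
        elif getattr(p, 'zero_is_sharded', False):     # the real code checks this attribute with hasattr
            zero_sharded_grads.append(p.grad.data)
        else:
            no_tensor_parallel_grads.append(p.grad.data)

    def lp_norm(grads):
        # local p-norm of one bucket, in the spirit of _calc_lp
        return sum(torch.norm(g, norm_type) ** norm_type for g in grads)

    # In the real function, the tensor-parallel norm is then all-reduced over the
    # TENSOR group and the zero-sharded norm over the DATA group, the partial
    # norms are summed, and the total is raised to the power 1 / norm_type.
    return (lp_norm(tensor_parallel_grads),
            lp_norm(no_tensor_parallel_grads),
            lp_norm(zero_sharded_grads))
```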
