[NFC] polish colossalai/engine/gradient_handler/_moe_gradient_handler.py (#3260)

pull/3313/head
LuGY 2023-03-27 18:47:44 +08:00 committed by binmakeswell
parent 204ca2f09a
commit 1ff7d5bfa5
1 changed file with 46 additions and 45 deletions


@@ -1,45 +1,46 @@
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict

from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce


@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group and
    moe model parallel. An all-reduce collective communication will be operated in
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are
    the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer=None):
        super().__init__(model, optimizer)

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        Then running an all-reduce operation for all parameters in experts
        across the moe model parallel group.
        """
        global_data = gpc.data_parallel_size

        if global_data > 1:
            epsize_param_dict = get_moe_epsize_param_dict(self._model)

            # epsize is 1, indicating the params are replicated among processes in data parallelism
            # use the ParallelMode.DATA to get data parallel group
            # reduce gradients for all parameters in data parallelism
            if 1 in epsize_param_dict:
                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))

            for ep_size in epsize_param_dict:
                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
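
For readers unfamiliar with the helper, the sketch below illustrates the bucketed all-reduce pattern that bucket_allreduce applies to each parameter group: gradients are bucketed by dtype, each bucket is flattened into one contiguous tensor, a single all-reduce averages it across the process group, and the result is copied back. This is a minimal sketch written against plain torch.distributed; the function name naive_bucket_allreduce is hypothetical, and the real helper in colossalai.engine.gradient_handler.utils may differ in details.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def naive_bucket_allreduce(param_list, group=None):
    # Bucket gradients by dtype so each bucket can be flattened into a
    # single contiguous tensor and reduced with one collective call.
    buckets = {}
    for param in param_list:
        if param.grad is not None:
            buckets.setdefault(param.grad.dtype, []).append(param.grad.data)

    world_size = dist.get_world_size(group=group)
    for grads in buckets.values():
        flat = _flatten_dense_tensors(grads)  # one tensor per bucket
        dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=group)
        flat.div_(world_size)  # average over the process group
        # Copy the reduced values back into the original gradient tensors.
        for grad, synced in zip(grads, _unflatten_dense_tensors(flat, grads)):
            grad.copy_(synced)

In MoeGradientHandler.handle_gradient this pattern runs once over ParallelMode.DATA for the purely data-parallel parameters (epsize 1) and once per remaining expert-parallel size over the matching MoE data-parallel sub-group, so expert gradients are averaged only among the ranks that hold replicas of those experts.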