ColossalAI/colossalai/engine/gradient_handler/_data_parallel_gradient_han...

28 lines
1.1 KiB
Python
Raw Normal View History

2021-10-28 16:21:23 +00:00
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
2021-10-28 16:21:23 +00:00
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
2021-10-28 16:21:23 +00:00
@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
    """Gradient handler that all-reduces gradients across the data parallel group.

    Calling :func:`handle_gradient` performs an all-reduce collective over the
    gradients of every parameter of the wrapped model, among the ranks of the
    data parallel group. The work is delegated to ``bucket_allreduce``, which
    (per the original design notes) bucketizes the gradients of parameters that
    share the same type in order to improve communication efficiency.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """All-reduce the model's gradients within the data parallel group.

        This is a no-op when the data parallel size is 1 or less, because
        there are no peer ranks to synchronize with.
        """
        # TODO: add memory buffer
        if gpc.data_parallel_size <= 1:
            # Only one data parallel rank: nothing to communicate.
            return

        # NOTE(review): self._model is presumably stored by
        # BaseGradientHandler.__init__ (not visible here) — confirm.
        params = self._model.parameters()
        dp_group = gpc.get_group(ParallelMode.DATA)
        bucket_allreduce(param_list=params, group=dp_group)