from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce


@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.

    An all-reduce collective communication will be performed in
    :func:`handle_gradient` across a data parallel group.
    For better performance, it bucketizes the gradients of all parameters of
    the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """
    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group."""
        # TODO: add memory buffer
        if gpc.data_parallel_size > 1:
            bucket_allreduce(param_list=self._model.parameters(), group=gpc.get_group(ParallelMode.DATA))
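

# Usage sketch (not part of the original module): a gradient handler is
# constructed with the model and optimizer, and ``handle_gradient`` is called
# after the backward pass so that gradients are synchronized across the data
# parallel group before the optimizer step. The training-loop objects below
# (``model``, ``optimizer``, ``loss``) are hypothetical placeholders, and the
# example assumes ``colossalai.launch`` has already initialized the global
# context ``gpc`` with a data parallel group.
#
#     handler = DataParallelGradientHandler(model, optimizer)
#     loss.backward()
#     handler.handle_gradient()  # bucketized all-reduce over ParallelMode.DATA
#     optimizer.step()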