2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
|
|
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.registry import GRADIENT_HANDLER
|
|
|
|
from ._base_gradient_handler import BaseGradientHandler
|
|
|
|
from ...context.parallel_mode import ParallelMode
|
|
|
|
|
|
|
|
|
|
|
|
@GRADIENT_HANDLER.register_module
|
|
|
|
class DataParallelGradientHandler(BaseGradientHandler):
|
|
|
|
"""A helper class to handle all-reduce operations in a data parallel group.
|
|
|
|
A all-reduce collective communication will be operated in
|
|
|
|
:func:`handle_gradient` among a data parallel group.
|
|
|
|
For better performance, it bucketizes the gradients of all parameters that are
|
|
|
|
the same type to improve the efficiency of communication.
|
|
|
|
"""
|
|
|
|
|
|
|
|
def handle_gradient(self):
|
|
|
|
"""A method running a all-reduce operation in a data parallel group.
|
|
|
|
"""
|
|
|
|
# TODO: add memory buffer
|
|
|
|
if gpc.data_parallel_size > 1:
|
|
|
|
# bucketize and all-reduce
|
|
|
|
buckets = {}
|
|
|
|
# Pack the buckets.
|
|
|
|
for param in self._model.parameters():
|
|
|
|
if param.requires_grad and param.grad is not None:
|
|
|
|
tp = param.data.type()
|
|
|
|
if tp not in buckets:
|
|
|
|
buckets[tp] = []
|
|
|
|
buckets[tp].append(param)
|
2021-12-20 15:26:19 +00:00
|
|
|
# param.main_grad = param.grad
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
# For each bucket, all-reduce and copy all-reduced grads.
|
|
|
|
for tp in buckets:
|
|
|
|
bucket = buckets[tp]
|
|
|
|
grads = [param.grad.data for param in bucket]
|
|
|
|
coalesced = _flatten_dense_tensors(grads)
|
|
|
|
coalesced /= gpc.get_world_size(ParallelMode.DATA)
|
|
|
|
|
|
|
|
dist.all_reduce(
|
|
|
|
coalesced, group=gpc.get_group(ParallelMode.DATA))
|
|
|
|
for buf, synced in zip(grads, _unflatten_dense_tensors(
|
|
|
|
coalesced, grads)):
|
|
|
|
buf.copy_(synced)
|