2021-12-30 07:56:46 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
|
2022-03-21 08:55:37 +00:00
|
|
|
from collections import defaultdict
|
2021-12-30 07:56:46 +00:00
|
|
|
|
2022-03-21 08:55:37 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
2023-04-05 15:24:43 +00:00
|
|
|
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
|
|
|
|
2021-12-30 07:56:46 +00:00
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
from colossalai.registry import GRADIENT_HANDLER
|
2022-03-21 08:55:37 +00:00
|
|
|
|
2021-12-30 07:56:46 +00:00
|
|
|
from ._base_gradient_handler import BaseGradientHandler
|
|
|
|
|
|
|
|
|
|
|
|
@GRADIENT_HANDLER.register_module
class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """Gradient handler that all-reduces gradients of pipeline-shared parameters.

    :func:`handle_gradient` sums gradients across the process group attached to
    each parameter through its ``pipeline_shared_module_pg`` attribute.
    To cut down on the number of collective calls, gradients are grouped by
    process group and by tensor type, and each group is reduced as a single
    flattened buffer.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """Run an all-reduce over gradients in the shared-module process groups.

        Does nothing unless ``gpc.pipeline_parallel_size`` is greater than 1.
        The reduced values are written back in place into the gradient tensors.
        """
        if gpc.pipeline_parallel_size <= 1:
            return

        # Layout: process group -> tensor type string -> list of parameters.
        params_by_group = defaultdict(lambda: defaultdict(list))
        for p in self._model.parameters():
            shared_pg = getattr(p, 'pipeline_shared_module_pg', None)
            if not p.requires_grad or shared_pg is None:
                continue
            # A gradient counts as present when either the `colo_attr` saved
            # gradient is non-null or the regular `.grad` is set.
            has_saved_grad = hasattr(p, 'colo_attr') and not p.colo_attr.saved_grad.is_null()
            if has_saved_grad or p.grad is not None:
                params_by_group[shared_pg][p.data.type()].append(p)

        # One flatten / all-reduce / copy-back round per (group, type) bucket.
        for shared_pg, typed_params in params_by_group.items():
            for same_type_params in typed_params.values():
                grad_tensors = []
                for p in same_type_params:
                    if hasattr(p, 'colo_attr'):
                        grad_tensors.append(p.colo_attr.grad_payload)
                    else:
                        grad_tensors.append(p.grad.data)

                flat = _flatten_dense_tensors(grad_tensors)
                flat = flat.to(torch.cuda.current_device())
                dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=shared_pg)

                synced_views = _unflatten_dense_tensors(flat, grad_tensors)
                for dst, src in zip(grad_tensors, synced_views):
                    dst.copy_(src)
|