diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/engine/gradient_handler/utils.py
index e92044b47..fca5f2ec9 100644
--- a/colossalai/engine/gradient_handler/utils.py
+++ b/colossalai/engine/gradient_handler/utils.py
@@ -1,29 +1,30 @@
-import torch.distributed as dist
-import torch.nn as nn
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from typing import Iterable
-
-
-def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
-    # get communication world size
-    comm_size = dist.get_world_size(group)
-    # bucketize and all-reduce
-    buckets = {}
-    # Pack the buckets.
-    for param in param_list:
-        if param.requires_grad and param.grad is not None:
-            tp = param.data.type()
-            if tp not in buckets:
-                buckets[tp] = []
-            buckets[tp].append(param)
-
-    # For each bucket, all-reduce and copy all-reduced grads.
-    for tp in buckets:
-        bucket = buckets[tp]
-        grads = [param.grad.data for param in bucket]
-        coalesced = _flatten_dense_tensors(grads)
-        coalesced /= comm_size
-
-        dist.all_reduce(coalesced, group=group)
-        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-            buf.copy_(synced)
+from typing import Iterable
+
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
+    # get communication world size
+    comm_size = dist.get_world_size(group)
+    # bucketize and all-reduce
+    buckets = {}
+    # Pack the buckets.
+    for param in param_list:
+        if param.requires_grad and param.grad is not None:
+            tp = param.data.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(param)
+
+    # For each bucket, all-reduce and copy all-reduced grads.
+    for tp in buckets:
+        bucket = buckets[tp]
+        grads = [param.grad.data for param in bucket]
+        coalesced = _flatten_dense_tensors(grads)
+        coalesced /= comm_size
+
+        dist.all_reduce(coalesced, group=group)
+        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+            buf.copy_(synced)
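
For reference, a minimal usage sketch of how bucket_allreduce might be invoked from a gradient handler after the backward pass. The sync_grads helper below is a hypothetical illustration, not an API introduced by this patch; it simply passes a model's parameters and an optional process group to the function shown in the diff.

    import torch.nn as nn

    from colossalai.engine.gradient_handler.utils import bucket_allreduce


    def sync_grads(model: nn.Module, group=None):
        # Hypothetical helper (not from this patch): average the gradients of all
        # trainable parameters across the given torch.distributed process group.
        bucket_allreduce(param_list=model.parameters(), group=group)

Within each dtype bucket, the gradients are flattened into one tensor, divided by the world size, all-reduced in a single collective call, and then copied back, which keeps the number of communication calls small compared with all-reducing each gradient individually.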