[NFC] polish colossalai/engine/gradient_handler/utils.py code style (#2708)

pull/2709/head
CZYCW 2 years ago committed by GitHub
parent 6427c406cf
commit 4ac8bfb072
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,29 +1,30 @@
import torch.distributed as dist from typing import Iterable
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors import torch.distributed as dist
from typing import Iterable import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
# get communication world size def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
comm_size = dist.get_world_size(group) # get communication world size
# bucketize and all-reduce comm_size = dist.get_world_size(group)
buckets = {} # bucketize and all-reduce
# Pack the buckets. buckets = {}
for param in param_list: # Pack the buckets.
if param.requires_grad and param.grad is not None: for param in param_list:
tp = param.data.type() if param.requires_grad and param.grad is not None:
if tp not in buckets: tp = param.data.type()
buckets[tp] = [] if tp not in buckets:
buckets[tp].append(param) buckets[tp] = []
buckets[tp].append(param)
# For each bucket, all-reduce and copy all-reduced grads.
for tp in buckets: # For each bucket, all-reduce and copy all-reduced grads.
bucket = buckets[tp] for tp in buckets:
grads = [param.grad.data for param in bucket] bucket = buckets[tp]
coalesced = _flatten_dense_tensors(grads) grads = [param.grad.data for param in bucket]
coalesced /= comm_size coalesced = _flatten_dense_tensors(grads)
coalesced /= comm_size
dist.all_reduce(coalesced, group=group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): dist.all_reduce(coalesced, group=group)
buf.copy_(synced) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

Loading…
Cancel
Save