mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/engine/gradient_handler/utils.py code style (#2708)
parent 6427c406cf
commit 4ac8bfb072
@@ -1,29 +1,30 @@
+from typing import Iterable
+
 import torch.distributed as dist
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from typing import Iterable
 
 
 def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
     # get communication world size
     comm_size = dist.get_world_size(group)
     # bucketize and all-reduce
     buckets = {}
     # Pack the buckets.
     for param in param_list:
         if param.requires_grad and param.grad is not None:
             tp = param.data.type()
             if tp not in buckets:
                 buckets[tp] = []
             buckets[tp].append(param)
 
     # For each bucket, all-reduce and copy all-reduced grads.
     for tp in buckets:
         bucket = buckets[tp]
         grads = [param.grad.data for param in bucket]
         coalesced = _flatten_dense_tensors(grads)
         coalesced /= comm_size
 
         dist.all_reduce(coalesced, group=group)
         for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
             buf.copy_(synced)
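
The helper relies on torch's internal _flatten_dense_tensors / _unflatten_dense_tensors pair to pack many small gradients into one contiguous buffer, so a single all_reduce covers an entire dtype bucket instead of one collective per parameter. Below is a minimal single-process sketch of that round trip; the tensor shapes and the in-place scaling that stands in for the collective are illustrative assumptions, not part of the commit.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Stand-ins for the per-parameter gradients of one dtype bucket.
grads = [torch.ones(3), torch.full((2, 2), 2.0)]

# Pack into one flat buffer: shape (3 + 4,) == (7,).
coalesced = _flatten_dense_tensors(grads)

# Pretend this in-place scaling was the result of dist.all_reduce.
coalesced *= 10

# Unflatten back to the original shapes and copy the values into place.
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

print(grads[0])  # tensor([10., 10., 10.])
print(grads[1])  # tensor([[20., 20.], [20., 20.]])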
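
For context, a hedged usage sketch of bucket_allreduce in a two-process data-parallel run; the toy nn.Linear model, the gloo backend, the demo.py filename, and the torchrun launch are assumptions for illustration, not something this commit prescribes.

# Hypothetical usage (assumption, not from this commit): average gradients
# across ranks after a local backward pass. Launch with e.g.
#   torchrun --nproc_per_node=2 demo.py
# so torch.distributed can read the rendezvous settings from the environment.
import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.engine.gradient_handler.utils import bucket_allreduce


def main():
    dist.init_process_group(backend="gloo")  # "nccl" for GPU training
    model = nn.Linear(16, 4)  # toy model; real use would sync initial weights

    # Each rank computes gradients on its own local batch.
    model(torch.randn(8, 16)).sum().backward()

    # bucket_allreduce groups grads by dtype, flattens each bucket, divides by
    # the world size and all-reduces, then copies the averages back into .grad.
    bucket_allreduce(model.parameters(), group=dist.group.WORLD)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()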