mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix PipelineSharedModuleGradientHandler (#1314)
parent
85f933b58b
commit
7c70bfbefa
|
@ -33,14 +33,19 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
|
||||||
# Pack the buckets.
|
# Pack the buckets.
|
||||||
for param in self._model.parameters():
|
for param in self._model.parameters():
|
||||||
group = getattr(param, 'pipeline_shared_module_pg', None)
|
group = getattr(param, 'pipeline_shared_module_pg', None)
|
||||||
if param.requires_grad and param.grad is not None and group is not None:
|
if param.requires_grad and group is not None and (
|
||||||
|
(hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null())
|
||||||
|
or param.grad is not None):
|
||||||
tp = param.data.type()
|
tp = param.data.type()
|
||||||
buckets[group][tp].append(param)
|
buckets[group][tp].append(param)
|
||||||
|
|
||||||
# For each bucket, all-reduce and copy all-reduced grads.
|
# For each bucket, all-reduce and copy all-reduced grads.
|
||||||
for group, group_buckets in buckets.items():
|
for group, group_buckets in buckets.items():
|
||||||
for tp, bucket in group_buckets.items():
|
for tp, bucket in group_buckets.items():
|
||||||
grads = [param.grad.data for param in bucket]
|
grads = [
|
||||||
|
param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data
|
||||||
|
for param in bucket
|
||||||
|
]
|
||||||
coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
|
coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
|
||||||
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
|
dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
|
||||||
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
|
||||||
|
|
Loading…
Reference in New Issue