
[hotfix] fix PipelineSharedModuleGradientHandler (#1314)

Branch: pull/1319/head
Author: ver217, committed via GitHub
Commit: 7c70bfbefa
1 changed file, 9 changed lines (+7, -2):
    colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py
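What the hotfix changes, as read from the diff below: parameters managed by ColossalAI's sharded tensors keep their gradient in param.colo_attr rather than in param.grad, so the old condition param.grad is not None skipped them when packing the all-reduce buckets. The new condition also accepts a parameter whose colo_attr.saved_grad is non-null, and the tensor that joins the all-reduce is taken from colo_attr.grad_payload when colo_attr is present. A minimal sketch of that selection branch follows (the requires_grad and process-group checks from the real condition are omitted); the helper names has_reducible_grad and get_reducible_grad are hypothetical and not part of ColossalAI's API:

# Hypothetical helpers sketching the gradient-selection branch introduced by this
# hotfix; illustration only, not part of ColossalAI.

def has_reducible_grad(param) -> bool:
    # Sharded parameters carry their gradient in colo_attr.saved_grad ...
    if hasattr(param, 'colo_attr'):
        return not param.colo_attr.saved_grad.is_null()
    # ... plain parameters carry it in param.grad.
    return param.grad is not None

def get_reducible_grad(param):
    # Return the tensor that should be packed into the shared-module all-reduce bucket.
    return param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data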
@@ -33,14 +33,19 @@ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
         # Pack the buckets.
         for param in self._model.parameters():
             group = getattr(param, 'pipeline_shared_module_pg', None)
-            if param.requires_grad and param.grad is not None and group is not None:
+            if param.requires_grad and group is not None and (
+                    (hasattr(param, 'colo_attr') and not param.colo_attr.saved_grad.is_null())
+                    or param.grad is not None):
                 tp = param.data.type()
                 buckets[group][tp].append(param)
 
         # For each bucket, all-reduce and copy all-reduced grads.
         for group, group_buckets in buckets.items():
             for tp, bucket in group_buckets.items():
-                grads = [param.grad.data for param in bucket]
+                grads = [
+                    param.colo_attr.grad_payload if hasattr(param, 'colo_attr') else param.grad.data
+                    for param in bucket
+                ]
                 coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
                 dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                 for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
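For context, the reduction around those lines follows the usual flatten, all-reduce, unflatten, copy-back pattern for a bucket of same-dtype gradients. A minimal self-contained sketch, assuming torch.distributed has already been initialized; allreduce_bucket is a hypothetical name, not a ColossalAI function:

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_bucket(grads, group=None):
    # Coalesce the same-dtype gradients into one contiguous buffer so the whole
    # bucket goes through a single collective call.
    coalesced = _flatten_dense_tensors(grads)
    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
    # Split the reduced buffer back into per-tensor views and copy each result
    # into the original gradient tensor.
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)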
