mirror of https://github.com/hpcaitech/ColossalAI
parent
5fb958cc83
commit
2069472e96
|
@ -338,14 +338,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
|
||||
else:
|
||||
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
|
||||
recieved_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
|
||||
received_grad = torch.zeros_like(flat_grads_list[0])
|
||||
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
|
||||
|
||||
if recieved_grad.dtype != grad_dtype:
|
||||
recieved_grad = recieved_grad.to(grad_dtype)
|
||||
if received_grad.dtype != grad_dtype:
|
||||
received_grad = received_grad.to(grad_dtype)
|
||||
|
||||
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank]
|
||||
self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1)
|
||||
self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1)
|
||||
|
||||
bucket_store.reset()
|
||||
|
||||
|
|
Loading…
Reference in New Issue