[Hotfix] Fix ZeRO typo #5936

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
pull/5941/head
Edenzzzz 2024-07-25 09:59:58 +08:00 committed by GitHub
parent 5fb958cc83
commit 2069472e96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 5 additions and 5 deletions

View File

@@ -338,14 +338,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
received_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)
if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
grad_in_bucket_current_rank = bucket_store.get_grad()[bucket_store.local_rank]
self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, recieved_grad, group_id, 1)
self._update_partitoned_grad(bucket_store, grad_in_bucket_current_rank, received_grad, group_id, 1)
bucket_store.reset()