|
|
@ -880,7 +880,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
|
|
|
grad_flat = torch.empty((grad_store.world_size, *grad.shape), dtype=grad.dtype, device=grad.device)
|
|
|
|
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
|
|
|
dist.all_gather_into_tensor(grad_flat, grad, group=grad_store.torch_pg)
|
|
|
|
return grad_flat[: working_param.numel()].reshape_as(working_param)
|
|
|
|
return grad_flat.view(-1)[: working_param.numel()].view_as(working_param)
|
|
|
|
|
|
|
|
|
|
|
|
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
|
|
|
def get_working_grads_by_group_id(self, group_id: int) -> List[Tensor]:
|
|
|
|
working_grads = []
|
|
|
|
working_grads = []
|
|
|
|