@@ -780,19 +780,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
-            if pinned_state_dicts and param not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param][k].copy_(param_state)
                         zero_state[param][k] = pinned_state_dicts[param][k]
                     else:
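Note on this hunk (context, not part of the diff): the truthiness checks on pinned_state_dicts are replaced with explicit `is not None` checks, and the pinned CPU buffer is now allocated from the gathered param_state after the all-gather instead of from working_param before it is assigned in that iteration. A minimal sketch of the pitfall the `is not None` check avoids, assuming a caller requests pinned buffers by passing a freshly created empty dict; the helper below is a hypothetical stand-in for the allocation done in the diff:

    def allocate_or_reuse_pinned_buffers():
        # stand-in for the per-key pinned buffer allocation in the hunk above
        print("pinned path taken")

    pinned_state_dicts = {}  # caller opts in with an empty dict, to be filled lazily

    # Old check: an empty dict is falsy, so the pinned path is silently skipped
    # on the first call and no pinned buffers are ever allocated.
    if pinned_state_dicts:
        allocate_or_reuse_pinned_buffers()  # never reached with an empty dict

    # New check: only None (pinning disabled) skips the pinned path.
    if pinned_state_dicts is not None:
        allocate_or_reuse_pinned_buffers()  # reached even while the dict is still empty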
@@ -858,7 +858,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
-            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
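Context for the pinned buffers used throughout these hunks (not part of the diff): the per-key CPU tensors are created with pin_memory=True so that device-to-host copies go through page-locked memory, which can also be issued as non-blocking copies. A small self-contained sketch of that pattern; the tensor names are illustrative and the non_blocking flag is an addition not used in the diff itself:

    import torch

    if torch.cuda.is_available():
        device_state = torch.randn(1024, device="cuda")
        # Page-locked host buffer, allocated once and then reused across checkpoints.
        pinned_buf = torch.empty_like(device_state, pin_memory=True, device="cpu")
        # Device-to-host copy into pinned memory.
        pinned_buf.copy_(device_state, non_blocking=True)
        torch.cuda.synchronize()  # ensure the copy finished before reading pinned_buf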
@@ -869,9 +869,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
                         pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
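Context for the gather pattern in the hunks above (not part of the diff): each rank holds an equally sized flat shard of an optimizer state tensor, the shards are all-gathered into one flat buffer, and the buffer is then trimmed to the working parameter's element count and reshaped, since the flat shards may carry padding. A conceptual single-process-group sketch using plain torch.distributed; all_gather_into_flat_tensor_nd and the ND process group are specific to the codebase, so the helper below is an illustrative stand-in, not the library's API:

    import torch
    import torch.distributed as dist

    def gather_full_state(shard: torch.Tensor, working_param: torch.Tensor, group=None) -> torch.Tensor:
        # Every rank contributes an equally sized flat shard, so the gathered
        # buffer holds shard.numel() * world_size elements.
        world_size = dist.get_world_size(group)
        gathered = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
        dist.all_gather_into_tensor(gathered, shard.flatten().contiguous(), group=group)
        # Drop any trailing padding and restore the parameter's original shape.
        return gathered[: working_param.numel()].reshape_as(working_param)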