From 184a65370451452bae87d0058bba06563028c4a8 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 19 Nov 2024 11:40:42 +0800
Subject: [PATCH] [checkpointio] fix pinned state dict

---
 colossalai/zero/low_level/low_level_optim.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 24ebae1c7..db26269b4 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -780,19 +780,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
-            if pinned_state_dicts and param not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param][k].copy_(param_state)
                         zero_state[param][k] = pinned_state_dicts[param][k]
                     else:
@@ -858,7 +858,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
-            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
@@ -869,9 +869,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
                         pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
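
For context: the patch replaces truthiness checks on pinned_state_dicts with explicit "is not None" checks, so a caller-supplied but still-empty cache dict is not silently ignored, and in state_dict it allocates the pinned CPU buffer from the gathered param_state after that tensor exists, instead of from working_param before working_param is assigned in the loop body. Below is a minimal standalone sketch of the same pinned-buffer pattern, assuming plain PyTorch on a CUDA-capable machine; gather_state_to_cpu and its arguments are illustrative names, not ColossalAI API.

    from typing import Dict, Optional

    import torch


    def gather_state_to_cpu(
        state_value: torch.Tensor,
        key: str,
        pinned_buffers: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Illustrative helper: copy one optimizer-state tensor to CPU, reusing a
        caller-provided pinned buffer when a cache dict is supplied."""
        # The cache may legally be an empty dict on the first call; a truthiness
        # test (`if pinned_buffers:`) would then skip pinning entirely, which is
        # why explicit `is not None` checks are used.
        if pinned_buffers is not None:
            if key not in pinned_buffers:
                # Allocate the pinned CPU buffer from the tensor being copied, so
                # shape and dtype match; pinning requires a CUDA-capable build.
                pinned_buffers[key] = torch.empty_like(state_value, pin_memory=True, device="cpu")
            # Device-to-host copy into pinned memory; the buffer is reused on later calls.
            pinned_buffers[key].copy_(state_value)
            return pinned_buffers[key]
        return state_value.cpu()


    # Usage: pass the same cache dict across calls so buffers are allocated only once.
    buffers = {} if torch.cuda.is_available() else None
    exp_avg = torch.randn(4, 4, device="cuda" if torch.cuda.is_available() else "cpu")
    cpu_state = gather_state_to_cpu(exp_avg, "exp_avg", buffers)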