diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e3c301640..24ebae1c7 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -18,6 +18,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import ( FP16MixedPrecisionMixin, MixedPrecisionMixin, ) +from colossalai.checkpoint_io.utils import calculate_tensor_size from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8 @@ -865,19 +866,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in states.items(): if isinstance(v, torch.Tensor) and k != "step": - if pinned_state_dicts and k not in pinned_state_dicts[param_idx]: - pinned_state_dicts[param_idx][k] = torch.empty_like( - working_param, pin_memory=True, device="cpu" - ) state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype) all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg) state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param) + if pinned_state_dicts and k not in pinned_state_dicts[param_idx]: + pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu") if pinned_state_dicts: pinned_state_dicts[param_idx][k].copy_(state_tensor) current_block[k] = pinned_state_dicts[param_idx][k] else: current_block[k] = state_tensor.cpu() - current_block_size += state_tensor.numel() + current_block_size += calculate_tensor_size(state_tensor) if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0: yield ret_block, ret_block_size