mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] fix size compute
parent
eb69e640e5
commit
5fa657f0a1
|
@ -18,6 +18,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
|
||||||
FP16MixedPrecisionMixin,
|
FP16MixedPrecisionMixin,
|
||||||
MixedPrecisionMixin,
|
MixedPrecisionMixin,
|
||||||
)
|
)
|
||||||
|
from colossalai.checkpoint_io.utils import calculate_tensor_size
|
||||||
from colossalai.interface import OptimizerWrapper
|
from colossalai.interface import OptimizerWrapper
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
|
from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
|
||||||
|
@ -865,19 +866,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
||||||
|
|
||||||
for k, v in states.items():
|
for k, v in states.items():
|
||||||
if isinstance(v, torch.Tensor) and k != "step":
|
if isinstance(v, torch.Tensor) and k != "step":
|
||||||
if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
|
|
||||||
pinned_state_dicts[param_idx][k] = torch.empty_like(
|
|
||||||
working_param, pin_memory=True, device="cpu"
|
|
||||||
)
|
|
||||||
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
|
||||||
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
|
all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
|
||||||
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
|
state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
|
||||||
|
if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
|
||||||
|
pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
|
||||||
if pinned_state_dicts:
|
if pinned_state_dicts:
|
||||||
pinned_state_dicts[param_idx][k].copy_(state_tensor)
|
pinned_state_dicts[param_idx][k].copy_(state_tensor)
|
||||||
current_block[k] = pinned_state_dicts[param_idx][k]
|
current_block[k] = pinned_state_dicts[param_idx][k]
|
||||||
else:
|
else:
|
||||||
current_block[k] = state_tensor.cpu()
|
current_block[k] = state_tensor.cpu()
|
||||||
current_block_size += state_tensor.numel()
|
current_block_size += calculate_tensor_size(state_tensor)
|
||||||
|
|
||||||
if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
|
if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
|
||||||
yield ret_block, ret_block_size
|
yield ret_block, ret_block_size
|
||||||
|
|
Loading…
Reference in New Issue