mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] optimize zero optim checkpoint io (#4591)
* [zero] update checkpoint io to save memory
* [checkpointio] add device map to save memory

parent cfa607080f
commit 63ecafb1fb
@@ -17,8 +17,13 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
     get_shard_filename,
+    load_param_groups_into_optimizer,
+    load_shard_state_dict,
+    load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
+    sharded_optimizer_loading_epilogue,
+    unwrap_optimizer,
 )
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device

@@ -126,9 +131,26 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
-        current_rank_state_dict = optimizer.optim.state_dict()['state']
-        for param_idx, state in current_rank_state_dict.items():
+        # If optimizer is wrapped, unwrap it.
+        if isinstance(optimizer, OptimizerWrapper):
+            optimizer = unwrap_optimizer(optimizer)
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
+                               Lacking param group file under current directory.')
+        id_map = load_param_groups_into_optimizer(optimizer, param_group_path)
+
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        for shard_file in checkpoint_files:
+            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            # shard state dict
+            for param_idx, state in state_dict.items():
                 for k, v in state.items():
                     if isinstance(v, torch.Tensor) and k != 'step':
                         padding_size = (self.coordinator.world_size -

@@ -138,7 +160,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                         if padding_size > 0:
                             v = torch.nn.functional.pad(v, [0, padding_size])
                         v_list = v.split(v.numel() // self.coordinator.world_size)
-                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
+                        state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
+            load_states_into_optimizer(optimizer, state_dict, id_map)
+
+        sharded_optimizer_loading_epilogue(optimizer)
 
 
 class LowLevelZeroModel(ModelWrapper):

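The slicing logic introduced above (pad, split, keep the local rank's chunk, then clone) can be illustrated in isolation. A minimal standalone sketch, not the plugin's actual code: `rank` and `world_size` stand in for `self.coordinator.rank` and `self.coordinator.world_size`, and the padding formula is an assumption reconstructed from the truncated `padding_size = (...)` line in the hunk.

```python
import torch

def shard_for_rank(v: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Pad a flat state tensor so it divides evenly, split it into
    world_size chunks, and keep only this rank's chunk."""
    # Assumed padding formula (the hunk truncates the real expression).
    padding_size = (world_size - v.numel() % world_size) % world_size
    if padding_size > 0:
        v = torch.nn.functional.pad(v, [0, padding_size])
    v_list = v.split(v.numel() // world_size)
    # .clone() copies only the local slice, so the full padded tensor
    # does not stay alive behind a view.
    return v_list[rank].detach().clone()

# Example: a 10-element state tensor sharded across 4 ranks.
print(shard_for_rank(torch.arange(10.0), rank=1, world_size=4))  # tensor([3., 4., 5.])
```
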
@@ -78,8 +78,6 @@ class GeneralCheckpointIO(CheckpointIO):
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             load_states_into_optimizer(optimizer, state_dict, id_map)
-            del state_dict
-            gc.collect()
 
         sharded_optimizer_loading_epilogue(optimizer)
 

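Dropping the explicit `del state_dict` / `gc.collect()` relies on ordinary reference counting: `state_dict` is rebound on every loop iteration, so the previous shard loses its last reference and is freed as soon as the next one is loaded. A rough sketch of the loop shape, with `consume` as a hypothetical stand-in for `load_states_into_optimizer`:

```python
from pathlib import Path
import torch

def load_all_shards(checkpoint_files, consume):
    for shard_file in checkpoint_files:
        # Keep each shard on CPU while reading it (see the map_location hunks below).
        state_dict = torch.load(Path(shard_file), map_location=torch.device('cpu'))
        consume(state_dict)
        # No del / gc.collect() needed: rebinding state_dict on the next
        # iteration releases this shard automatically.
```
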
@@ -237,7 +237,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
                 f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
-        return torch.load(checkpoint_file)
+        return torch.load(checkpoint_file, map_location=torch.device('cpu'))
 
 
 def load_state_dict_into_model(model: nn.Module,

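The `map_location` argument is the "device map" mentioned in the commit message: without it, `torch.load` restores each tensor to the device it was saved from, so reading a checkpoint written from GPU tensors would allocate CUDA memory just to load it. Pinning the load to CPU keeps the shard in host memory until a slice is actually needed. A small self-contained check (file name and contents are made up for illustration):

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.save({'exp_avg': torch.randn(4, 4, device=device)}, 'shard_0.bin')

# map_location forces every tensor onto CPU regardless of where it was saved.
loaded = torch.load('shard_0.bin', map_location=torch.device('cpu'))
assert all(t.device.type == 'cpu' for t in loaded.values())
```
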
@@ -297,7 +297,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
 
     # Load list of param_groups from given file path.
     # The params in saved_groups are in the form of integer indices.
-    saved_groups = torch.load(param_group_path)
+    saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
     if not isinstance(saved_groups, List):
         raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
 

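For context on what `load_param_groups_into_optimizer` (imported from `colossalai.checkpoint_io.utils` in the first hunk) produces as `id_map`: the comment above notes that the saved groups store params as integer indices. A hedged sketch of the general idea, assuming the helper restores group hyperparameters and maps each saved index to the corresponding live parameter; the real implementation may differ.

```python
import torch

def build_id_map(optimizer: torch.optim.Optimizer, saved_groups: list) -> dict:
    """Assumed shape of the helper: restore hyperparameters (lr, betas, ...)
    and map saved integer param indices to the optimizer's live parameters."""
    id_map = {}
    for group, saved in zip(optimizer.param_groups, saved_groups):
        # Copy everything except the parameter list itself.
        group.update({k: v for k, v in saved.items() if k != 'params'})
        # Map each saved integer index back to the corresponding live param.
        id_map.update(dict(zip(saved['params'], group['params'])))
    return id_map
```
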
@@ -608,7 +608,7 @@ def load_state_dict(checkpoint_file_path: Path):
 
     else:
         # load with torch
-        return torch.load(checkpoint_file_path)
+        return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
 
 
 def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:

@@ -553,11 +553,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if padding_size > 0:
                     v = torch.nn.functional.pad(v, [0, padding_size])
                 v_list = v.split(v.numel() // self._world_size)
-                device = 'cpu' if self._cpu_offload else 'cuda'
-                zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach()
+                zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone()
 
         self.optim.load_state_dict(zero_state_dict)
-        zero_state_dict = dict()
 
     def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
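The key change in this last hunk is swapping `.to(device).detach()` for `.detach().clone()`: when the target device was `'cuda'`, the old code copied every slice to GPU up front, and when it was `'cpu'`, the `.to` was a no-op that returned a view still pinning the full padded tensor's storage. `.clone()` keeps the slice on CPU and lets the padded tensor be freed. A small demonstration of the storage-sharing point (PyTorch >= 2.0 for `untyped_storage`; sizes are arbitrary):

```python
import torch

full = torch.zeros(1 << 20)              # ~4 MB flat tensor
view = full.split(full.numel() // 4)[0]  # a view sharing full's storage
copy = view.detach().clone()             # an independent copy of just the slice

# The view keeps the entire 4 MB storage alive; the clone owns only its slice.
assert view.untyped_storage().nbytes() == full.untyped_storage().nbytes()
assert copy.untyped_storage().nbytes() == copy.numel() * copy.element_size()
```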