
polish checkpoint docstring (#637)

pull/633/head
ver217 committed f5d3a9c2b0, 3 years ago
1 changed file: colossalai/utils/checkpointing.py (63 lines changed)
@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
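
For readers skimming the hunk above: partition_tensor_parallel_state_dict splits each tensor in the state dict across the ranks of a tensor-parallel group, with dims naming the split dimension per parameter. A minimal single-process sketch of that dim-wise split (the helper name and the dim-0 default here are illustrative, not from the commit):

    import torch
    from collections import OrderedDict

    def split_state_dict_along_dims(state_dict, rank, depth, dims):
        # Keep only this rank's chunk of each tensor; `dims` maps a parameter
        # name to the dimension it is partitioned along (assumed default: dim 0).
        out = OrderedDict()
        for name, tensor in state_dict.items():
            out[name] = torch.chunk(tensor, depth, dim=dims.get(name, 0))[rank]
        return out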
@@ -51,11 +52,11 @@ def partition_tensor_parallel_state_dict(
-def gather_tensor_parallel_state_dict(
-    state_dict: OrderedDict,
-    parallel_mode: ParallelMode,
-    dims: dict = dict(),
-    partition_states: dict = dict(),
-    keep_vars: bool = False,
-):
+def gather_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                      parallel_mode: ParallelMode,
+                                      dims: dict = dict(),
+                                      partition_states: dict = dict(),
+                                      keep_vars: bool = False):
     dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
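
gather_tensor_parallel_state_dict is the inverse: each rank contributes its shard and the shards are concatenated back along the same dimension. A quick round-trip check with the same illustrative setup:

    import torch

    # Chunking along a dim and concatenating along that dim reproduces the
    # original tensor (when the size divides evenly across ranks).
    w = torch.randn(8, 4)
    shards = torch.chunk(w, 4, dim=0)  # one shard per tensor-parallel rank
    assert torch.equal(torch.cat(shards, dim=0), w)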
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
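
The dist.gather_object call above follows the usual torch.distributed convention: only the destination rank allocates the output list, every other rank passes None. The same pattern in isolation (a sketch assuming an initialized process group; the function name is illustrative):

    import torch.distributed as dist

    def gather_objects_to_rank0(obj, group=None):
        # Destination rank provides placeholders; all other ranks pass None.
        out = [None] * dist.get_world_size(group) if dist.get_rank(group) == 0 else None
        dist.gather_object(obj, out, dst=0, group=group)
        return out  # populated list on rank 0, None elsewhere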
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
     return state_dict
 
 
-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.
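
On the destination rank, the per-stage OrderedDicts gathered above are flattened into one state dict; itertools.chain.from_iterable keeps the pipeline-stage order. A standalone illustration with made-up keys:

    from collections import OrderedDict
    from itertools import chain

    stage0 = OrderedDict([("embed.weight", 0)])
    stage1 = OrderedDict([("head.weight", 1)])
    merged = OrderedDict(chain.from_iterable(s.items() for s in [stage0, stage1]))
    # OrderedDict([('embed.weight', 0), ('head.weight', 1)])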
@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
@@ -195,7 +188,7 @@ def load_checkpoint(
 ):
     """Loads training states from a checkpoint file.
 
     Args:
         file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
             object containing a file name.
         model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
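
Loading mirrors saving: only one rank per MODEL group reads the file, and map_location=torch.device("cpu") keeps the full checkpoint out of GPU memory until the partitions are dispatched. A single-process stand-in for that guarded load (the flag replaces the gpc rank check and is illustrative):

    import torch

    is_io_rank = True  # stand-in for gpc.get_local_rank(ParallelMode.MODEL) == 0
    state_dict = torch.load("ckpt_epoch_10.pt", map_location=torch.device("cpu")) if is_io_rank else None
    model_state = state_dict.pop("model") if state_dict is not None else dict()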
@@ -231,11 +223,8 @@ def load_checkpoint(
         dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
         if gpc.get_global_rank() == 0:
             all_error_msgs = list(chain.from_iterable(all_error_msgs))
-            raise RuntimeError(
-                "Error(s) in loading state_dict for {}:\n\t{}".format(
-                    model.__class__.__name__, "\n\t".join(all_error_msgs)
-                )
-            )
+            raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                model.__class__.__name__, "\n\t".join(all_error_msgs)))
     else:
         raise e
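
The error path gathers every rank's error_msgs onto one rank before raising, so the failure surfaces once with all messages attached. The same aggregation pattern in isolation (a sketch assuming an initialized process group; the helper is illustrative):

    import torch.distributed as dist
    from itertools import chain

    def raise_load_errors_once(local_msgs, cls_name):
        all_msgs = [None] * dist.get_world_size() if dist.get_rank() == 0 else None
        dist.gather_object(local_msgs, all_msgs, dst=0)
        if dist.get_rank() == 0:
            flat = list(chain.from_iterable(all_msgs))
            raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
                cls_name, "\n\t".join(flat)))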
