polish checkpoint docstring (#637)

pull/633/head
ver217 2022-04-02 13:34:33 +08:00 committed by GitHub
parent 055fbf5be6
commit f5d3a9c2b0
1 changed file with 26 additions and 37 deletions


@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
 
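
Reading from the parameter names, dims and partition_states map each state-dict entry to a split dimension and to a flag saying whether the entry should be split across the tensor-parallel group at all. As a rough standalone sketch of that per-rank partitioning idea, using plain torch and hypothetical names rather than the ColossalAI internals:

    import torch

    def partition_param(param: torch.Tensor, dim: int, depth: int, rank: int) -> torch.Tensor:
        # Split the full parameter into `depth` equal chunks along `dim`
        # and keep only the chunk owned by this tensor-parallel rank.
        return torch.chunk(param, depth, dim=dim)[rank].clone()

    full_weight = torch.randn(8, 4)
    local_weight = partition_param(full_weight, dim=0, depth=4, rank=1)
    print(local_weight.shape)  # torch.Size([2, 4])
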
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 
 
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
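
The allocation above follows the standard torch.distributed.gather_object convention: only the destination rank supplies a receive list, every other rank passes None. A minimal self-contained sketch of that pattern, assuming an already-initialized CPU (gloo) process group rather than the ColossalAI context:

    import torch.distributed as dist

    def gather_partial_state(partial_state: dict, dst: int = 0):
        # Only the destination rank allocates the receive buffer; all
        # other ranks must pass None per the gather_object contract.
        gathered = ([None for _ in range(dist.get_world_size())]
                    if dist.get_rank() == dst else None)
        dist.gather_object(partial_state, gathered, dst=dst)
        return gathered  # list of per-rank dicts on `dst`, None elsewhere
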
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
 
-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
 
     return state_dict
 
 
-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.
@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
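
Only the first rank of the MODEL parallel group reads the file here; the other ranks start from None and receive their share through the module's broadcast and partition helpers. A standalone sketch of the load-once-then-share pattern with plain torch.distributed, using broadcast_object_list purely for illustration:

    import torch
    import torch.distributed as dist

    def load_on_one_rank(path: str, src: int = 0):
        # Only `src` touches the filesystem; the object list is then
        # broadcast so every rank ends up with the same checkpoint dict.
        obj = [torch.load(path, map_location="cpu")] if dist.get_rank() == src else [None]
        dist.broadcast_object_list(obj, src=src)
        return obj[0]
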
@@ -231,11 +223,8 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError(
-                    "Error(s) in loading state_dict for {}:\n\t{}".format(
-                        model.__class__.__name__, "\n\t".join(all_error_msgs)
-                    )
-                )
+                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
             else:
                 raise e
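
The reformatted error path above collects every rank's load errors and raises one combined message on the global rank 0. For completeness, a hedged end-to-end usage sketch of load_checkpoint; the exact signature and return value are not shown in this diff, so the epoch return below is an assumption:

    import torch
    from colossalai.utils import load_checkpoint  # assumed import path

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Assumption: the call restores model/optimizer states in place and
    # returns the epoch number stored by save_checkpoint.
    last_epoch = load_checkpoint("checkpoint_epoch_10.pt", model, optimizer)
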