From f5d3a9c2b0ee1b1339d18d49a17de163bd8c9583 Mon Sep 17 00:00:00 2001
From: ver217
Date: Sat, 2 Apr 2022 13:34:33 +0800
Subject: [PATCH] polish checkpoint docstring (#637)

---
 colossalai/utils/checkpointing.py | 63 +++++++++++++------------------
 1 file changed, 26 insertions(+), 37 deletions(-)

diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py
index 34eaa2ea0..6341e907c 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/utils/checkpointing.py
@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
 
@@ -51,11 +52,11 @@ def partition_tensor_parallel_state_dict(
 
 
 def gather_tensor_parallel_state_dict(
-    state_dict: OrderedDict,
-    parallel_mode: ParallelMode,
-    dims: dict = dict(),
-    partition_states: dict = dict(),
-    keep_vars: bool = False,
+        state_dict: OrderedDict,
+        parallel_mode: ParallelMode,
+        dims: dict = dict(),
+        partition_states: dict = dict(),
+        keep_vars: bool = False,
 ):
     dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 
 
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
 
-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
 
     return state_dict
 
 
-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model,
     optimizer, lr_scheduler etc. into a checkpoint dictionary.
 
@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
@@ -195,7 +188,7 @@ def load_checkpoint(
 ):
     """Loads training states from a checkpoint file.
 
-    Args: 
+    Args:
         file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string
             or os.PathLike object containing a file name.
         model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
@@ -231,11 +223,8 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError(
-                    "Error(s) in loading state_dict for {}:\n\t{}".format(
-                        model.__class__.__name__, "\n\t".join(all_error_msgs)
-                    )
-                )
+                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
         else:
             raise e
 
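The user-facing effect of the polished docstrings is easiest to see in a usage sketch. The snippet below is based only on the signatures and docstrings visible in the patch; the file name and training objects are placeholders, the load_checkpoint keyword arguments are inferred from its docstring rather than shown in full above, and it assumes a ColossalAI launch has already initialized the gpc parallel groups that checkpointing.py queries internally.

import torch
from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# Bundle model/optimizer/lr_scheduler states and the epoch counter into one
# checkpoint file; per the docstring, extra kwargs such as pickle_protocol
# are also accepted.
save_checkpoint("ckpt.pt", epoch=10, model=model,
                optimizer=optimizer, lr_scheduler=lr_scheduler)

# Restore the same components later; per the docstring, a RuntimeError is
# raised if the model/optimizer states cannot be recovered.
load_checkpoint("ckpt.pt", model=model,
                optimizer=optimizer, lr_scheduler=lr_scheduler)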
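Two of the reformatted helpers (gather_pipeline_parallel_state_dict, and the error-message gathering inside load_checkpoint) rely on the same gather-to-rank-0 idiom. The sketch below reproduces that idiom with plain torch.distributed primitives so it can be read in isolation; the function name is invented for the example, and it uses the default process group (assumed to be initialized via init_process_group) instead of ColossalAI's gpc groups.

from collections import OrderedDict
from itertools import chain

import torch.distributed as dist


def gather_partial_state_dicts(partial: OrderedDict) -> OrderedDict:
    # Illustrative stand-in for gather_pipeline_parallel_state_dict:
    # every rank contributes its shard; rank 0 merges them all.
    rank = dist.get_rank()

    # Only the destination rank allocates the receive list; every other
    # rank must pass None -- the same reason the patched code builds
    # gathered_states conditionally.
    gathered = [None for _ in range(dist.get_world_size())] if rank == 0 else None
    dist.gather_object(partial, gathered, dst=0)

    # On rank 0, flatten the per-rank dicts into a single state dict;
    # elsewhere return an empty OrderedDict, mirroring the patched function.
    if rank == 0:
        return OrderedDict(chain.from_iterable(state.items() for state in gathered))
    return OrderedDict()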