
polish checkpoint docstring (#637)

Branch: pull/633/head
Author: ver217, committed via GitHub 3 years ago
Commit: f5d3a9c2b0
1 changed file: colossalai/utils/checkpointing.py (63 changed lines)
@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
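For orientation, a minimal sketch of the broadcast idea behind broadcast_state_dict from the hunk header: wrap the dict in a one-element object list, broadcast it from the source rank, and return element 0 on every rank. The helper name broadcast_state_dict_sketch is hypothetical, it uses torch.distributed.broadcast_object_list rather than the library's internals, and it assumes a process group has already been initialized (e.g. via colossalai.launch):

    import torch.distributed as dist

    def broadcast_state_dict_sketch(state_dict, src=0, group=None):
        # Only the source rank needs to supply the real dict; every rank gets a copy back.
        payload = [state_dict if dist.get_rank(group) == src else None]
        dist.broadcast_object_list(payload, src=src, group=group)
        return payload[0]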
@@ -51,11 +52,11 @@ def partition_tensor_parallel_state_dict(
 def gather_tensor_parallel_state_dict(
-    state_dict: OrderedDict,
-    parallel_mode: ParallelMode,
-    dims: dict = dict(),
-    partition_states: dict = dict(),
-    keep_vars: bool = False,
+    state_dict: OrderedDict,
+    parallel_mode: ParallelMode,
+    dims: dict = dict(),
+    partition_states: dict = dict(),
+    keep_vars: bool = False,
 ):
     dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
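The two hunks above only re-wrap the signatures; as a reminder of what these helpers do, here is a minimal, self-contained sketch of the partition-along-dim idea behind partition_tensor_parallel_state_dict (and, in reverse, gather_tensor_parallel_state_dict). The function partition_for_rank and the single-process setting are illustrative assumptions, not the library's implementation:

    from collections import OrderedDict

    import torch

    def partition_for_rank(state_dict, dims, partition_states, rank, world_size):
        # Keep only this rank's chunk of each tensor marked as partitioned; other
        # entries are kept whole on every rank.
        partitioned = OrderedDict()
        for name, tensor in state_dict.items():
            if partition_states.get(name, False):
                dim = dims.get(name, 0)
                partitioned[name] = torch.chunk(tensor, world_size, dim=dim)[rank].clone()
            else:
                partitioned[name] = tensor
        return partitioned

    full = OrderedDict(weight=torch.randn(8, 4), bias=torch.randn(8))
    shard = partition_for_rank(full,
                               dims={"weight": 0, "bias": 0},
                               partition_states={"weight": True, "bias": True},
                               rank=1,
                               world_size=2)
    print(shard["weight"].shape)  # torch.Size([4, 4]): rank 1's half of the weight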
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 
 
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
     return state_dict
 
 
-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.
@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
@@ -195,7 +188,7 @@ def load_checkpoint(
 ):
     """Loads training states from a checkpoint file.
 
-    Args:
+    Args:
         file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
             object containing a file name.
         model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
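A matching, hypothetical resume sketch for load_checkpoint, under the same assumptions as the save example above; treating the return value as the saved epoch number follows the docstring's framing of "training states" and is noted here as an assumption:

    from colossalai.utils import load_checkpoint  # assumed export, as above

    last_epoch = load_checkpoint("checkpoint_epoch_0.pt",
                                 model,
                                 optimizer=optimizer,
                                 lr_scheduler=lr_scheduler)
    start_epoch = last_epoch + 1  # continue training from the following epoch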
@@ -231,11 +223,8 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError(
-                    "Error(s) in loading state_dict for {}:\n\t{}".format(
-                        model.__class__.__name__, "\n\t".join(all_error_msgs)
-                    )
-                )
+                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
         else:
             raise e
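The hunk above reflows the error-aggregation path: each MODEL-parallel rank's load errors are gathered to a destination rank, which raises one combined RuntimeError. A small single-process simulation of that merge step (the per-rank lists and the model name are made up):

    from itertools import chain

    per_rank_errors = [
        ["size mismatch for linear.weight"],              # rank 0
        [],                                               # rank 1
        ['missing key(s) in state_dict: "linear.bias"'],  # rank 2
    ]
    all_error_msgs = list(chain.from_iterable(per_rank_errors))
    if all_error_msgs:
        message = "Error(s) in loading state_dict for {}:\n\t{}".format(
            "MyModel", "\n\t".join(all_error_msgs))
        print(message)  # the library raises RuntimeError(message) on the destination rank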
