mirror of https://github.com/hpcaitech/ColossalAI
polish checkpoint docstring (#637)
parent 055fbf5be6
commit f5d3a9c2b0
@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]


-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)

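For orientation, the reformatted signature above takes a full state dict plus per-key `dims` and `partition_states` mappings. The sketch below shows how a caller might invoke it, under the assumption (not stated in this diff) that `dims` maps each state-dict key to the dimension it is split along and `partition_states` marks which keys are split at all; the key names, tensor shapes, and import path are illustrative only, and the call requires an initialized ColossalAI parallel context.

# Hypothetical call site -- key names, shapes, and the interpretation of
# `dims`/`partition_states` are assumptions, not taken from this commit.
from collections import OrderedDict

import torch
from colossalai.context import ParallelMode
from colossalai.utils.checkpointing import partition_tensor_parallel_state_dict  # assumed path

full_state = OrderedDict([("weight", torch.randn(1024, 1024)), ("bias", torch.randn(1024))])
local_state = partition_tensor_parallel_state_dict(
    full_state,
    ParallelMode.PARALLEL_1D,
    dims={"weight": -1, "bias": 0},                    # assumed: split dimension per key
    partition_states={"weight": True, "bias": True},   # assumed: whether the key is split
)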
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):


 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
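This hunk, together with the next one, compacts the gather-and-merge pattern: rank 0 of the pipeline group collects every stage's partial state dict via dist.gather_object and later concatenates the shards in rank order. Below is a minimal, self-contained sketch of the same pattern written against plain torch.distributed rather than ColossalAI's gpc; the function and variable names are mine, not from this commit.

from collections import OrderedDict
from itertools import chain

import torch.distributed as dist


def gather_state_shards(local_shard: OrderedDict, dst: int = 0) -> OrderedDict:
    # Every rank contributes its shard; only the destination rank allocates the receive list.
    gathered = [None] * dist.get_world_size() if dist.get_rank() == dst else None
    dist.gather_object(local_shard, gathered, dst=dst)
    if dist.get_rank() == dst:
        # Merge in rank order, mirroring the chain.from_iterable call in the next hunk.
        return OrderedDict(chain.from_iterable(shard.items() for shard in gathered))
    return OrderedDict()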
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )

-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())

     return state_dict


-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.

@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
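The polished docstring documents the signature save_checkpoint(file, epoch, model, optimizer=None, lr_scheduler=None, **kwargs), with pickle_module and pickle_protocol passed through the keyword arguments. A hedged usage sketch follows; the import path and checkpoint file name are assumptions based on the surrounding module, not stated in the diff, and the call is meant to run inside an initialized ColossalAI training context.

import torch
from colossalai.utils import save_checkpoint  # assumed import path

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# optimizer and lr_scheduler are optional per the docstring above.
save_checkpoint("./ckpt_epoch_10.pt",
                epoch=10,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler)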
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)

     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
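The reformatted line loads the checkpoint from disk only on local rank 0 of the MODEL group; the other ranks start from None and obtain their states afterwards (the broadcast_state_dict helper referenced in the first hunk points at how). Below is a stand-alone sketch of that read-once-then-share idea using plain torch.distributed, offered as an illustration of the intent rather than a copy of ColossalAI's implementation.

import torch
import torch.distributed as dist


def load_once_and_share(path: str):
    # Only rank 0 touches the filesystem; every other rank receives the object afterwards.
    payload = [torch.load(path, map_location="cpu")] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]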
@@ -231,11 +223,8 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError(
-                    "Error(s) in loading state_dict for {}:\n\t{}".format(
-                        model.__class__.__name__, "\n\t".join(all_error_msgs)
-                    )
-                )
+                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
         else:
             raise e
