@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]


-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)

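For intuition about what `partition_tensor_parallel_state_dict` is partitioning, here is a minimal standalone sketch, not taken from the library: it assumes partitioning amounts to splitting each parameter into `depth` chunks along its configured dim, one chunk per rank in the group (the dim, depth and rank values below are made up).

import torch

# Hypothetical example values: world size ("depth") of the tensor-parallel group,
# the dim configured for this entry, and this process's local rank in the group.
depth, dim, local_rank = 2, 0, 1
full_weight = torch.arange(12.0).reshape(4, 3)              # stand-in for a full parameter
local_shard = torch.chunk(full_weight, depth, dim=dim)[local_rank]
print(local_shard.shape)                                    # torch.Size([2, 3])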
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):


 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
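A note on the gather above: only the destination rank needs to allocate the receive list, which is why `gathered_states` stays `None` on the other pipeline ranks. A self-contained, single-process sketch of the same `torch.distributed.gather_object` pattern (gloo backend, world size 1; the address, port and payload are arbitrary):

import os
import torch.distributed as dist

# Single-process group purely for illustration; address/port values are arbitrary.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

obj = {"linear.weight": "stage-0 tensor"}              # stand-in for one stage's state_dict
gathered = [None] if dist.get_rank() == 0 else None    # only the destination rank allocates the list
dist.gather_object(obj, gathered, dst=0)
print(gathered)                                        # [{'linear.weight': 'stage-0 tensor'}]
dist.destroy_process_group()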
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )

-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())

     return state_dict


-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.

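After the gather, pipeline rank 0 flattens the per-stage dictionaries into a single `OrderedDict`. A standalone sketch of that merge, with toy keys and values in place of real parameters:

from collections import OrderedDict
from itertools import chain

# Two toy "stage" state_dicts standing in for what dist.gather_object collects.
gathered_states = [
    OrderedDict([("embed.weight", "tensor from stage 0")]),
    OrderedDict([("head.weight", "tensor from stage 1")]),
]
merged = OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
print(list(merged.keys()))  # ['embed.weight', 'head.weight']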
@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
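As a rough picture of the checkpoint dictionary this docstring describes, a self-contained sketch with toy components follows; apart from "model", which `load_checkpoint` pops further down, the key names and the file name here are assumptions rather than anything confirmed by this diff.

import torch

# Toy training components; the real ones come from a ColossalAI training run.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

checkpoint = {
    "epoch": 3,                                 # how many epochs have been trained
    "model": model.state_dict(),                # key confirmed by load_checkpoint below
    "optimizer": optimizer.state_dict(),        # assumed key name
    "lr_scheduler": lr_scheduler.state_dict(),  # assumed key name
}
torch.save(checkpoint, "demo_checkpoint.pt")    # hypothetical file name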
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)

     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
@@ -231,11 +223,8 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError(
-                    "Error(s) in loading state_dict for {}:\n\t{}".format(
-                        model.__class__.__name__, "\n\t".join(all_error_msgs)
-                    )
-                )
+                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
             else:
                 raise e

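For reference, the error handling above flattens the per-rank message lists and joins them into one string before raising. A standalone sketch of just that string assembly, with made-up messages and class name:

from itertools import chain

# Error messages as they might arrive from two model-parallel ranks.
all_error_msgs = [
    ['Missing key(s) in state_dict: "head.weight".'],
    ['size mismatch for embed.weight: expected [32, 8], got [16, 8].'],
]
flat = list(chain.from_iterable(all_error_msgs))
message = "Error(s) in loading state_dict for {}:\n\t{}".format("MyModel", "\n\t".join(flat))
print(message)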