|
|
@@ -23,9 +23,10 @@ def broadcast_state_dict(state_dict, parallel_mode):
     return state_dict[0]
 
 
-def partition_tensor_parallel_state_dict(
-    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
-):
+def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
+                                         parallel_mode: ParallelMode,
+                                         dims: dict = dict(),
+                                         partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
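Context for reviewers: `broadcast_state_dict`, whose tail is the context above, replicates one rank's copy of a state dict to every rank in the group, which is why it returns `state_dict[0]` out of a single-element list. A minimal sketch of that pattern with plain `torch.distributed` (the `gpc` wrapper and its CPU group are assumed from context, not shown in this hunk; `src_rank` is a global rank, as in `gpc.get_ranks_in_group(...)[0]`):

```python
import torch.distributed as dist

def broadcast_state_dict_sketch(state_dict, group, src_rank):
    # broadcast_object_list mutates the list in place: src_rank supplies
    # the payload and every other rank in the group receives it.
    payload = [state_dict if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(payload, src=src_rank, group=group)
    return payload[0]
```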
|
@@ -51,11 +52,11 @@ def partition_tensor_parallel_state_dict(
 
 
 def gather_tensor_parallel_state_dict(
     state_dict: OrderedDict,
     parallel_mode: ParallelMode,
     dims: dict = dict(),
     partition_states: dict = dict(),
     keep_vars: bool = False,
 ):
     dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
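The bodies of these two functions are not visible in this hunk, but the signatures document the contract: `dims` maps parameter names to the dimension a tensor is split along, `partition_states` marks which entries are split at all, and the group's world size (`depth`) gives the number of shards. A hypothetical standalone sketch of that partitioning step, not this file's actual body:

```python
import torch
from collections import OrderedDict

def partition_for_rank(state_dict: OrderedDict, rank: int, depth: int,
                       dims: dict, partition_states: dict) -> OrderedDict:
    """Keep only this rank's shard of every partitioned tensor."""
    out = OrderedDict()
    for name, tensor in state_dict.items():
        if partition_states.get(name, False):
            dim = dims.get(name, 0)
            # Split into `depth` chunks along the configured dim; clone so
            # the shard does not keep the full tensor's storage alive.
            out[name] = torch.chunk(tensor, depth, dim=dim)[rank].clone()
        else:
            out[name] = tensor
    return out
```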
|
|
@@ -124,11 +125,8 @@ def partition_pipeline_parallel_state_dict(model, state_dict):
 
 
 def gather_pipeline_parallel_state_dict(state_dict):
-    gathered_states = (
-        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else None
-    )
+    gathered_states = ([None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+                       if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else None)
     dist.gather_object(
         state_dict,
         gathered_states,
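`dist.gather_object` requires the destination buffer only on the gathering rank and `None` everywhere else, which is exactly what the reformatted conditional above builds. The call pattern in isolation, with a generic process group standing in for `gpc`'s pipeline group:

```python
import torch.distributed as dist

def gather_dicts(local_dict, group, dst_rank):
    # Only the destination rank allocates the gather list; gather_object
    # insists that object_gather_list is None on every other rank.
    world_size = dist.get_world_size(group)
    gathered = [None] * world_size if dist.get_rank() == dst_rank else None
    dist.gather_object(local_dict, gathered, dst=dst_rank, group=group)
    return gathered  # list of per-rank dicts on dst_rank, None elsewhere
```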
|
|
@@ -136,23 +134,18 @@ def gather_pipeline_parallel_state_dict(state_dict):
         group=gpc.get_cpu_group(ParallelMode.PIPELINE),
     )
-    state_dict = (
-        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
-        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
-        else OrderedDict()
-    )
+    state_dict = (OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+                  if gpc.get_local_rank(ParallelMode.PIPELINE) == 0 else OrderedDict())
     return state_dict
 
 
-def save_checkpoint(
-    file,
-    epoch: int,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer = None,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-    **kwargs
-):
+def save_checkpoint(file,
+                    epoch: int,
+                    model: torch.nn.Module,
+                    optimizer: torch.optim.Optimizer = None,
+                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+                    **kwargs):
     """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model, optimizer,
     lr_scheduler etc. into a checkpoint dictionary.
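The `chain.from_iterable` expression in this hunk is what merges the gathered per-stage dictionaries back into a single `OrderedDict`, in pipeline order. In isolation (toy keys, not this model's):

```python
from collections import OrderedDict
from itertools import chain

# Shaped like what gather_object delivers on the destination rank.
gathered_states = [
    OrderedDict(embed_weight=0),  # stage 0
    OrderedDict(head_weight=1),   # stage 1
]
# Later stages would overwrite duplicate keys, so keys must be disjoint.
merged = OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
assert list(merged) == ["embed_weight", "head_weight"]
```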
|
@@ -162,8 +155,8 @@ def save_checkpoint(
         epoch (int): Epoch number (indicates how many epochs have you trained this model).
         model (:class:`torch.nn.Module`): Model to be saved.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
-        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
+            lr_scheduler to be saved, defaults to None.
         pickle_module: module used for pickling metadata and objects
         pickle_protocol: can be specified to override the default protocol
     """
|
|
@@ -195,7 +188,7 @@ def load_checkpoint(
 ):
     """Loads training states from a checkpoint file.
 
     Args:
         file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
             object containing a file name.
         model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
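And the matching loader; per the docstring above, `file` may be a path or any file-like object implementing read(), readline(), tell(), and seek(). A usage sketch continuing the example (the return value, assumed here to be the stored epoch, comes from the surrounding implementation rather than this hunk):

```python
from colossalai.utils import load_checkpoint  # import path assumed

# Restores model/optimizer/lr_scheduler states in place.
last_epoch = load_checkpoint("ckpt_epoch10.pt", model,
                             optimizer=optimizer, lr_scheduler=lr_scheduler)
```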
|
|
@@ -211,9 +204,8 @@ def load_checkpoint(
     Raises:
         RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    state_dict = (
-        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
-    )
+    state_dict = (torch.load(file, map_location=torch.device("cpu"))
+                  if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None)
 
     # model states
     model_state = state_dict.pop("model") if state_dict is not None else dict()
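Two details worth noting in this hunk: the checkpoint is deserialized by a single rank (local rank 0 of the MODEL group) and always onto CPU, and every other rank falls through to an empty dict so all ranks execute the same subsequent code path. The idiom in isolation:

```python
import torch

def load_on_one_rank(file, is_loading_rank: bool):
    # map_location="cpu" avoids a GPU memory spike while deserializing.
    state_dict = torch.load(file, map_location=torch.device("cpu")) if is_loading_rank else None
    # pop-with-fallback: non-loading ranks simply see an empty sub-dict.
    model_state = state_dict.pop("model") if state_dict is not None else dict()
    return state_dict, model_state
```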
|
|
@@ -231,11 +223,8 @@ def load_checkpoint(
             dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
             if gpc.get_global_rank() == 0:
                 all_error_msgs = list(chain.from_iterable(all_error_msgs))
-                raise RuntimeError(
-                    "Error(s) in loading state_dict for {}:\n\t{}".format(
-                        model.__class__.__name__, "\n\t".join(all_error_msgs)
-                    )
-                )
+                raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
+                    model.__class__.__name__, "\n\t".join(all_error_msgs)))
         else:
             raise e
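The final hunk reflows the error-aggregation path: each rank's `load_state_dict` error messages are gathered with `dist.gather_object`, and global rank 0 flattens the per-rank lists with `chain.from_iterable` before raising one combined `RuntimeError`. The flatten-and-raise step on its own:

```python
from itertools import chain

def raise_combined(all_error_msgs, model_cls_name):
    # all_error_msgs is a list of per-rank lists of messages, exactly as
    # dist.gather_object fills it in on the destination rank.
    flat = list(chain.from_iterable(all_error_msgs))
    raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(
        model_cls_name, "\n\t".join(flat)))
```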
|