from collections import OrderedDict
from itertools import chain
from typing import Optional

import torch
import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX

from .common import is_using_pp

__all__ = ["save_checkpoint", "load_checkpoint"]


def broadcast_state_dict(state_dict, parallel_mode):
    """Broadcast ``state_dict`` from the first rank of ``parallel_mode`` to the group.

    Args:
        state_dict: Object to broadcast. A ``dict`` is shallow-copied before
            being handed to the collective; any other value (e.g. ``None`` on
            receiving ranks) is passed through unchanged.
        parallel_mode (ParallelMode): Group to broadcast within.

    Returns:
        The object held by the first rank of the group.
    """
    payload = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    dist.broadcast_object_list(payload, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
    return payload[0]


def partition_tensor_parallel_state_dict(
    state_dict: OrderedDict,
    parallel_mode: ParallelMode,
    dims: Optional[dict] = None,
    partition_states: Optional[dict] = None,
):
    """Split a full state dict on local rank 0 and scatter one shard to every rank.

    Note that ``state_dict`` is emptied on local rank 0: every entry is popped
    while the shards are being built.

    Args:
        state_dict (OrderedDict): Full state dict; only read on local rank 0.
        parallel_mode (ParallelMode): Group to partition across.
        dims (dict, optional): Per-key dimension to chunk along; keys that are
            absent use dimension 0.
        partition_states (dict, optional): Per-key flag; ``False`` sends the
            whole value to every rank instead of a chunk. Keys that are absent
            are partitioned.

    Returns:
        The shard assigned to the calling rank.
    """
    # ``None`` sentinels instead of shared mutable ``dict()`` defaults.
    dims = {} if dims is None else dims
    partition_states = {} if partition_states is None else partition_states

    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    depth = gpc.get_world_size(parallel_mode)

    if gpc.get_local_rank(parallel_mode) == 0:
        partitioned_state_list = [dict() for _ in range(depth)]

        for key in list(state_dict.keys()):
            param = state_dict.pop(key)
            dim = dims.get(key, 0)
            do_partition = partition_states.get(key, True)
            if do_partition:
                param = torch.chunk(param, depth, dim=dim)
            for i, shard in enumerate(partitioned_state_list):
                shard[key] = param[i] if do_partition else param
    else:
        partitioned_state_list = [None for _ in range(depth)]

    partitioned_state = [None]
    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))

    return partitioned_state[0]


def gather_tensor_parallel_state_dict(
    state_dict: OrderedDict,
    parallel_mode: ParallelMode,
    dims: Optional[dict] = None,
    partition_states: Optional[dict] = None,
    keep_vars: bool = False,
):
    """Gather tensor-parallel shards onto local rank 0 of ``parallel_mode``.

    ``state_dict`` is modified in place: every entry is popped, and only local
    rank 0 gets the entries re-inserted, so other ranks end with an empty dict.

    Args:
        state_dict (OrderedDict): Local shard of the state dict.
        parallel_mode (ParallelMode): Group to gather within.
        dims (dict, optional): Per-key dimension along which the value was
            partitioned; keys that are absent use dimension 0.
        partition_states (dict, optional): Per-key flag; ``False`` means the
            value is not partitioned and the local copy is kept as is. Keys
            that are absent are treated as partitioned.
        keep_vars (bool): If ``False``, tensors are detached before gathering.

    Returns:
        OrderedDict: The same ``state_dict`` object.
    """
    # ``None`` sentinels instead of shared mutable ``dict()`` defaults.
    dims = {} if dims is None else dims
    partition_states = {} if partition_states is None else partition_states

    dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
    depth = gpc.get_world_size(parallel_mode)

    for key in list(state_dict.keys()):
        param = state_dict.pop(key)
        param = param if keep_vars else param.detach()
        dim = dims.get(key, 0)
        do_partition = partition_states.get(key, True)
        if do_partition:
            # Move the partitioned dimension to the front so that shards can be
            # gathered into contiguous dim-0 slices of a single buffer.
            temp = param.transpose(0, dim).contiguous()
            gather_list = None
            if gpc.get_local_rank(parallel_mode) == 0:
                shape = list(param.shape)
                shape[0], shape[dim] = shape[dim], shape[0]
                shape[0] *= depth
                param = torch.empty(shape, dtype=param.dtype, device=param.device)
                # The chunks are the destinations dist.gather fills in.
                gather_list = list(torch.chunk(param, depth, dim=0))
            dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
            # Undo the transpose applied before gathering.
            param = torch.transpose(param, 0, dim)
        # update params in state_dict only on local rank 0
        if gpc.get_local_rank(parallel_mode) == 0:
            state_dict[key] = param

    return state_dict


def _send_state_dict(state_dict, dst, parallel_mode):
    """Serialize ``state_dict`` and send it to global rank ``dst``.

    The byte size is sent first, then the serialized payload, so that the
    receiver can allocate a buffer of the right length.
    """
    # NOTE(review): relies on a private torch helper; confirm its signature
    # matches the installed torch version.
    state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
    group = gpc.get_cpu_group(parallel_mode)
    dist.send(state_size, dst, group=group)
    dist.send(state_tensor, dst, group=group)


def _recv_state_dict(src, parallel_mode):
    """Receive and deserialize a state dict sent by :func:`_send_state_dict`."""
    group = gpc.get_cpu_group(parallel_mode)
    state_size = torch.tensor([0], dtype=torch.long)
    dist.recv(state_size, src, group=group)
    state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
    dist.recv(state_tensor, src, group=group)
    # NOTE(review): private torch helper, see _send_state_dict.
    state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
    return state_dict


def partition_pipeline_parallel_state_dict(model, state_dict):
    """Keep the entries of ``state_dict`` that belong to this pipeline stage.

    On tensor-parallel local rank 0, each stage receives the remaining states
    from the previous stage (the first stage uses the ``state_dict`` argument),
    pops the parameters, buffers and module extra states that ``model`` owns,
    and forwards whatever is left to the next stage. Other tensor-parallel
    ranks do no communication and return an empty dict.

    Args:
        model (torch.nn.Module): The local pipeline stage.
        state_dict (dict): Full model states; only used on the first stage.

    Returns:
        OrderedDict: States owned by ``model`` on this stage.
    """
    pipeline_state = OrderedDict()

    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        # receive all states from prev stage
        if not gpc.is_first_rank(ParallelMode.PIPELINE):
            state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
        # move states to output
        for name, _ in model.named_parameters(recurse=True):
            if name in state_dict:
                pipeline_state[name] = state_dict.pop(name)
        for name, _ in model.named_buffers(recurse=True):
            if name in state_dict:
                pipeline_state[name] = state_dict.pop(name)
        for name, _ in model.named_modules():
            # The root module is named "" and nn.Module.state_dict() stores its
            # extra state under the bare suffix, without a leading dot.
            extra_state_key = f"{name}.{_EXTRA_STATE_KEY_SUFFIX}" if name else _EXTRA_STATE_KEY_SUFFIX
            if extra_state_key in state_dict:
                pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
        # send rest states to next stage
        if not gpc.is_last_rank(ParallelMode.PIPELINE):
            _send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)

    return pipeline_state


def gather_pipeline_parallel_state_dict(state_dict):
    """Merge the state dicts of all pipeline stages onto pipeline local rank 0.

    Args:
        state_dict (dict): States of the local pipeline stage.

    Returns:
        OrderedDict: The merged states on pipeline local rank 0, an empty
        ``OrderedDict`` on every other rank.
    """
    is_dst = gpc.get_local_rank(ParallelMode.PIPELINE) == 0
    # gather_object requires the output list to be None off the destination.
    gathered_states = [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))] if is_dst else None
    dist.gather_object(
        state_dict,
        gathered_states,
        dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
        group=gpc.get_cpu_group(ParallelMode.PIPELINE),
    )

    if not is_dst:
        return OrderedDict()
    return OrderedDict(chain.from_iterable(state.items() for state in gathered_states))


def save_checkpoint(
    file,
    epoch: int,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer = None,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
    **kwargs
):
    """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model,
    optimizer, lr_scheduler etc. into a checkpoint dictionary.

    Only global rank 0 writes the file. Optimizer and lr_scheduler states are
    currently not stored (the corresponding code is commented out below).

    Args:
        file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
            file name.
        epoch (int): Epoch number (indicates how many epochs have you trained this model).
        model (:class:`torch.nn.Module`): Model to be saved.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
        lr_scheduler (Union[:class:`torch.optim.lr_scheduler`, :class:`colossalai.nn.lr_scheduler`], optional):
            lr_scheduler to be saved, defaults to None.
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    # ckpt container
    checkpoint = {"epoch": epoch}

    model_state = model.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        model_state = gather_pipeline_parallel_state_dict(model_state)

    if gpc.get_global_rank() == 0:
        checkpoint["model"] = model_state

        # if optimizer is not None:
        #     checkpoint['optimizer'] = optimizer.state_dict()

        # if lr_scheduler is not None:
        #     checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

        torch.save(checkpoint, file, **kwargs)


def load_checkpoint(
    file,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer = None,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
    strict: bool = True,
):
    """Loads training states from a checkpoint file.

    The file is read on local rank 0 of the MODEL group only. Optimizer and
    lr_scheduler states are currently not restored (the corresponding code is
    commented out below).

    Args:
        file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
            object containing a file name.
        model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
        lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
            lr_scheduler to recuperate, defaults to None.
        strict (bool, optional): Whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint
            match the names of parameters and buffers in model, defaults to True.

    Returns:
        int: The saved epoch number.

    Raises:
        RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
    """
    is_model_dst = gpc.get_local_rank(ParallelMode.MODEL) == 0

    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(file, map_location=torch.device("cpu")) if is_model_dst else None

    # model states
    model_state = state_dict.pop("model") if state_dict is not None else dict()
    # pipeline
    if is_using_pp():
        model_state = partition_pipeline_parallel_state_dict(model, model_state)
    try:
        model.load_state_dict(model_state, strict=strict)
    except RuntimeError as e:
        error_msgs = str(e)
        if error_msgs.startswith("Error(s) in loading state_dict for "):
            error_msgs = error_msgs.split("\n\t")[1:]
            dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
            # gather_object requires the output list to be None on every rank
            # except the destination; passing a list elsewhere raises ValueError
            # there while the destination keeps waiting in the collective.
            all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))] if is_model_dst else None
            dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
            # Only global rank 0 reports the combined error.
            if gpc.get_global_rank() == 0:
                all_error_msgs = list(chain.from_iterable(all_error_msgs))
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        model.__class__.__name__, "\n\t".join(all_error_msgs)
                    )
                )
        else:
            # Not a key-mismatch report: re-raise with the original traceback.
            raise

    # broadcast the rest states
    state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)

    # # optimizer states
    # if optimizer is not None and 'optimizer' in state_dict:
    #     optimizer.load_state_dict(state_dict['optimizer'])

    # # lr scheduler states
    # if lr_scheduler is not None and 'lr_scheduler' in state_dict:
    #     lr_scheduler.load_state_dict(state_dict['lr_scheduler'])

    # last epoch
    last_epoch = state_dict.pop("epoch", -1)

    return last_epoch