diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py index ddef565c5..635eebd89 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py @@ -1,205 +1,617 @@ +import copy import logging import os from pathlib import Path +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from colossalai.checkpoint_io import CheckpointIndexFile -from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model -from colossalai.moe import MoECheckpintIO -from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor +from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO +from colossalai.checkpoint_io.index_file import CheckpointIndexFile +from colossalai.checkpoint_io.utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + load_shard_state_dict, + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, +) +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.moe import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" -class MixtralMoECheckpointIO(MoECheckpintIO): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: + super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose) + moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size] + self.ep_group = moe_info.ep_group + self.ep_size = moe_info.ep_size + self.ep_rank = moe_info.ep_rank + self.real_dp_rank = moe_info.dp_rank - @torch.no_grad() - def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict: + @staticmethod + def _model_sharder( + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + param_name_pattern: Optional[str] = None, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + if param_name_pattern is not None and param_name_pattern not in name: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. 
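+        # Note: unlike the parameters above, buffers and extra states are not filtered by param_name_pattern.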
+ for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: """ - Preprocess state_dict before loading and slice the state_dict of MOE tensors. - """ - model_param_dict = dict(model.named_parameters()) - for name, param in list(state_dict.items()): - if ".gate.weight" in name: - new_name = "module." + name.replace(".gate.weight", ".gate_weight") - state_dict[new_name] = state_dict.pop(name) - elif ".experts." in name: - # if is moe tensor - # in our moe module, expert is cat as one tensor - # but mixtral's experts is not cat - # we will insert the loaded expert into the position of cat tensor + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" - # get model param - str_idx = name.index(".experts.") - expert_idx = int(name.split(".")[-3]) - if ".w1." in name: - model_param_name = name.replace(name[str_idx:], ".experts.wi_gate") - elif ".w2." in name: - model_param_name = name.replace(name[str_idx:], ".experts.wo") - elif ".w3." in name: - model_param_name = name.replace(name[str_idx:], ".experts.wi_up") - model_param_name = "module." + model_param_name - # skip for pipeline - if model_param_name not in model_param_dict: - continue - model_param = model_param_dict[model_param_name] - assert is_moe_tensor(model_param) - # get expert range - ep_rank = get_ep_rank(model_param) - ep_size = get_ep_size(model_param) - expert_num = 8 // ep_size - expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num)) - # insert new param - if expert_idx in expert_range: - new_param = model_param - new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1) - state_dict[model_param_name] = new_param - state_dict.pop(name) - else: - new_name = "module." + name - state_dict[new_name] = state_dict.pop(name) - - dist.barrier() - return state_dict - - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): - """ - Load sharded model with the given path to index file of checkpoint folder. Args: - model (nn.Module): The model to be loaded. - checkpoint_index_file (str): Path to the index file of checkpointing folder. - strict (bool, optional): For name matching during loading state_dict. Defaults to False. 
- This argument should be manually set to False since params on same device might be stored in different files. + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ - # Check whether the checkpoint uses safetensors. - use_safetensors = False - if "safetensors" in checkpoint_index_file.name: - use_safetensors = True + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() - if use_safetensors and not is_safetensors_available(): - raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + if self.real_dp_rank != 0: + return + + # ep_rank 0 saves all the parameters and buffers. + # other ep_ranks save only experts + ep_param_pattern = "experts." if self.ep_rank != 0 else None + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern + ) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. 
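+            # Both the pipeline stage and the expert-parallel rank are encoded in the filenames so that
+            # shards written by different (pp_rank, ep_rank) pairs do not collide.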
+ weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + weights_name = weights_name.replace( + ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors" + ) + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + dist.barrier(self.ep_group) + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) + rmtree(tmp_index_file_folder) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + @staticmethod + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + is_moe_param: bool, + device: torch.device = torch.device("cpu"), + ) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # First gather Zero shards. 
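+                # The gather below only applies to non-MoE states; MoE parameter states are taken as-is,
+                # since they are not flattened across this dp_group.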
+ if use_zero and not is_moe_param: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().to(device) + + return state_ + + @staticmethod + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + size_per_shard: int = 1024, + only_moe_param: bool = False, + ): + # An internel method that breaks state_dict of optimizer into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + is_moe_param=is_moe_tensor(working_param), + ) + + if only_moe_param and not is_moe_tensor(working_param): + continue + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. 
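+        # real_dp_rank is the data-parallel rank with the expert-parallel dimension factored out (taken from MOE_MANAGER in __init__).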
+ if not self.use_zero and self.real_dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + size_per_shard=size_per_shard, + only_moe_param=self.ep_rank != 0, + ) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.real_dp_rank == 0 and self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) + + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + dist.barrier(self.ep_group) + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. 
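+            # Only the global master writes the shared param_groups file, alongside the merged index file.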
+ final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) ckpt_root_path = ckpt_index_file.root_path weight_map = ckpt_index_file.weight_map - strict = False + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int - # Load params & buffers to model. + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + # ep param groups + if len(optimizer.optim.param_groups) == len(saved_groups) + 1: + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. # Keep a record of loaded files so that file will not be repeatedly loaded. 
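+        # Params whose ids are absent from the checkpoint's weight_map are skipped silently rather than raising an error.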
loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] - def _load(name: str): - if name not in weight_map: - raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") - filename = weight_map[name] + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue - # If this param/buffer has been loaded before, directly return. - if filename in loaded_file: - return + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) - file_path = os.path.join(ckpt_root_path, filename) - state_dict = load_shard_state_dict(Path(file_path), use_safetensors) - state_dict = self.pre_load_model(model, state_dict) - missing_keys = [] - - load_state_dict_into_model( - model, - state_dict, - missing_keys=missing_keys, - strict=strict, - load_sub_module=True, + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + is_moe_param=is_moe_tensor(working_param), ) - loaded_file.add(filename) + optimizer.optim.state[param] = sharded_state - # Load parameters. - for name, _ in model.named_parameters(): - name = name.replace("module.", "") - name = name.replace(".gate_weight", ".gate.weight") - if ".experts.wi_gate" in name: - for i in range(8): - new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight") - _load(new_name) - elif ".experts.wi_up" in name: - for i in range(8): - new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight") - _load(new_name) - elif ".experts.wo" in name: - for i in range(8): - new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight") - _load(new_name) - else: - _load(name) + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose and self.coordinator.is_master(): + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - if self.verbose: - logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + is_moe_param: bool, + ) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. - @torch.no_grad() - def pre_save_model(self, model: nn.Module) -> dict: - torch.cuda.empty_cache() - state_dict = model.state_dict() - for name, param in list(model.named_parameters()): - if ".gate_weight" in name: - new_name = name.replace(".gate_weight", ".gate.weight") - state_dict[new_name] = state_dict.pop(name).cpu() - elif ".experts." 
in name: - ep_group = get_ep_group(param) - ep_rank = get_ep_rank(param) - ep_size = get_ep_size(param) - dp_rank = get_dp_rank(param) + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. - if dp_rank == 0: - param = param.data.cuda() - all_param = [torch.zeros_like(param) for _ in range(ep_size)] - # gather param from every ep rank - dist.all_gather(all_param, param, group=ep_group) - if ep_rank == 0: - all_param = torch.cat(all_param, dim=0) - assert all_param.shape[0] == 8 - for i in range(8): - if ".wi_gate" in name: - new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight") - elif ".wi_up" in name: - new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight") - elif ".wo" in name: - new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight") - new_name = new_name.replace("module.", "") - new_param = all_param[i].transpose(-1, -2) - state_dict[new_name] = new_param.cpu() - state_dict.pop(name) - else: - state_dict[name] = param.cpu() + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) - for name, param in list(state_dict.items()): - new_name = name.replace("module.", "") - state_dict[new_name] = state_dict.pop(name) - - torch.cuda.empty_cache() - if self.pp_size > 1: - if self.dp_rank == 0: - # gather state_dict from every pp rank - # because ckpt is large, we split it into 10 parts - # and gather them one by one - new_state_dict = {} - state_dict_keys = list(state_dict.keys()) - gap_key_num = min(30, len(state_dict_keys)) - gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num - for i in range(gap_key_num): - cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys] - cur_state_dict = {} - for k in cur_keys: - cur_state_dict[k] = state_dict[k] - out = [None for _ in range(self.pp_size)] - dist.all_gather_object(out, cur_state_dict, group=self.pp_group) - if self.pp_rank == 0: - for o in out: - for k, v in o.items(): - new_state_dict[k] = v.cpu() - state_dict = new_state_dict - dist.barrier() - return state_dict + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. 
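+                # Reproduce ZeRO's flat layout: pad the state to a multiple of dp_size, flatten it,
+                # and keep only this rank's slice; MoE states are kept whole.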
+ if self.use_zero and not is_moe_param: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False): + raise NotImplementedError diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py index e395c8578..a2b78a2bd 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py @@ -1,80 +1,92 @@ import torch -import torch.nn as nn -from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock +import torch.distributed as dist +import torch.nn.functional as F +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from colossalai.lazy import LazyInitContext -from colossalai.moe import SparseMLP +from colossalai.moe import MOE_MANAGER +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_info -class MixtralSparseMLP: - r""" - This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. - """ +class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): + def __init__(self, config): + super().__init__(config) + self.setup_ep() - def __init__(self) -> None: - raise NotImplementedError( - "FusedLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex." - ) + def setup_ep(self): + _, moe_info = MOE_MANAGER.get_info(self.num_experts) + ep_group = moe_info.ep_group + self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 + self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + for p in self.experts.parameters(): + set_moe_tensor_info(p, moe_info) @staticmethod - def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module: - r""" - Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex, - and optionally marking parameters for gradient aggregation. 
+ def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + LazyInitContext.materialize(module) + module.__class__ = EPMixtralSparseMoeBlock + module.setup_ep() + return module - Args: - module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted. - sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism. + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) - Returns: - nn.Module: Union[FastLayerNorm, FusedLayerNorm]. + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) - Raises: - AssertionError: If the provided module is not an instance of nn.LayerNorm. - """ - with torch.no_grad(): - LazyInitContext.materialize(module) + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + output_split_sizes = torch.zeros_like(input_split_sizes) + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) - # get the attributes of the module - moe_kwargs = dict( - num_experts=8, - hidden_size=module.hidden_dim, - intermediate_size=module.ffn_dim, - router_top_k=module.top_k, - router_norm=True, - router_loss=False, - # router_capacity_factor_train= - # router_capacity_factor_eval= - mlp_activation="silu", - mlp_gated=True, - # enable_load_balance= - # load_balance_tolerance= - # load_balance_beam_width= - # load_balance_group_swap_factor= - enable_kernel=enable_kernel, - # enable_comm_overlap= - # enable_hierarchical_comm= - return_gate_logits=True, - ) - dtype = module.gate.weight.dtype - device = module.gate.weight.device - sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device) - - return sparse_mlp - - -def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module: - """ - Reverse the replace layer operation - - Args: - module (torch.nn.Module): The object of layer to shard - """ - if isinstance(model, MixtralDecoderLayer): - model.block_sparse_moe = MixtralSparseMLP.from_native_module( - model.block_sparse_moe, enable_kernel=enable_kernel + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + # compute expert output + output_states = MoeInGradScaler.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) + output_states = expert.w2(output_states) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in 
enumerate(output_states_splits): + if split_states.size(0) == 0: + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) + split_states = expert.w2(split_states) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device ) - else: - for _, child in model.named_children(): - replace_moe_layer(child, enable_kernel) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + return output_states, router_logits diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py index 2f6021f2d..734695278 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -20,6 +20,8 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from colossalai.shardformer.shard import ShardConfig +from .mixtral_layer import EPMixtralSparseMoeBlock + __all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] @@ -51,6 +53,18 @@ class MixtralPolicy(Policy): if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + # optimization configuration if self.shard_config.enable_fused_normalization: self.append_or_create_submodule_replacement( diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py index 70b827264..a2a0a7e78 100644 --- a/applications/ColossalMoE/colossal_moe/utils.py +++ b/applications/ColossalMoE/colossal_moe/utils.py @@ -3,7 +3,6 @@ import os from typing import Any, Dict, Tuple, Union import torch -from huggingface_hub import snapshot_download from torch.optim.lr_scheduler import _LRScheduler from torch.optim.optimizer import Optimizer @@ -15,23 +14,6 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -@torch.no_grad() -def load_model(ckpt_path: str, model, booster: Booster, optimizer=None): - # pytorch ckpt - if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")): - ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json") - # saved ckpt - elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): - ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") - # download - else: - ckpt_path = snapshot_download(ckpt_path) - 
booster.load_model(model, ckpt_path) - if optimizer is not None: - optimizer.sync_moe_master_param() - optimizer.update_master_params(model) - - def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: """ Load file in JSON format @@ -90,7 +72,7 @@ def load_checkpoint( """ # Update booster params states. - load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer) + booster.load_model(model, os.path.join(load_dir, "modeling")) booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py index d234fb628..46ff70ff3 100644 --- a/applications/ColossalMoE/infer.py +++ b/applications/ColossalMoE/infer.py @@ -2,10 +2,8 @@ import argparse import torch import torch.distributed as dist -from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO -from colossal_moe.models.mixtral_layer import replace_moe_layer +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy -from colossal_moe.utils import load_model from transformers import AutoTokenizer from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM @@ -13,9 +11,6 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator -from colossalai.moe import MOE_MANAGER -from colossalai.moe.utils import skip_init -from colossalai.utils import get_current_device def parse_args(): @@ -30,16 +25,10 @@ def parse_args(): parser.add_argument( "--plugin", type=str, - default="hybrid", + default="ep", choices=["ep"], help="Parallel methos.", ) - parser.add_argument( - "--output_path", - type=str, - default="./outputs", - help="The path of your saved model after finetuning.", - ) parser.add_argument( "--precision", type=str, @@ -71,60 +60,38 @@ def main(): colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() + config = MixtralConfig.from_pretrained(args.model_name) + ep_size = min(dist.get_world_size(), config.num_local_experts) # Set plugin - booster_kwargs = {} - hybrid_dict = { - "tp_size": 1, - "custom_policy": MixtralForCausalLMPolicy(), - "enable_fused_normalization": args.use_layernorm_kernel, - "enable_jit_fused": args.use_kernel, - "precision": args.precision, - "checkpoint_io": MixtralMoECheckpointIO, - "zero_stage": 1, - } - mgr_dict = {} if args.plugin == "ep": - dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( + tp_size=1, pp_size=1, - **hybrid_dict, - ) - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=dp_size, - **mgr_dict, + ep_size=ep_size, + zero_stage=1, + precision=args.precision, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, ) else: raise ValueError(f"Invalid plugin {args.plugin}") coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") # Build mixtral model - config = MixtralConfig.from_pretrained(args.model_name) - config.num_local_experts = 1 # dont change this. 
it will not affect model - with skip_init(): - model = MixtralForCausalLM(config) - model.num_experts = 8 - model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16) - model = model.to(get_current_device()) - coordinator.print_on_master(f"Finish init model with config:\n{config}") - - # Replace moe - with skip_init(): - replace_moe_layer(model) - model.eval() - coordinator.print_on_master(f"Finish replace moe module") + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish load model") # Prepare tokenizer and dataloader tokenizer = AutoTokenizer.from_pretrained(args.model_name) # Set booster - booster = Booster(plugin=plugin, **booster_kwargs) + booster = Booster(plugin=plugin) model, _, _, _, _ = booster.boost(model=model) coordinator.print_on_master(f"Finish init booster") - # load ckpt - load_model(args.model_name, model, booster) - coordinator.print_on_master(f"Finish load ckpt") + model.eval() if coordinator.rank == 0: text = ["Hello my name is"] @@ -132,10 +99,13 @@ def main(): text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] tokenizer.pad_token = tokenizer.unk_token inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) - outputs = model.module.generate(**inputs, max_new_tokens=20) - outputs = tokenizer.batch_decode(outputs) + + with torch.no_grad(): + outputs = model.module.generate(**inputs, max_new_tokens=20) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) print(f"[{coordinator.rank}] {outputs}") + if __name__ == "__main__": main() diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py new file mode 100644 index 000000000..57589ab20 --- /dev/null +++ b/applications/ColossalMoE/tests/test_mixtral_layer.py @@ -0,0 +1,63 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock +from torch.testing import assert_close +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.moe import MOE_MANAGER +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + MOE_MANAGER.setup( + parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1 + ) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + ) + torch.manual_seed(0) + orig_model = MixtralSparseMoeBlock(config).cuda() + x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() + orig_output, orig_logits = orig_model(x) + model = deepcopy(orig_model) + model = EPMixtralSparseMoeBlock.from_native_module(model) + ep_output, ep_logits = model(x) + assert_close(orig_logits, ep_logits) + assert_close(orig_output, ep_output) + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss) + name_to_p = {n: p for n, p in orig_model.named_parameters()} + for n, ep_p in model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad) + + +def run_dist(rank: int, world_size: 
int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(2) diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py index 7c6012a70..d3848bc14 100644 --- a/applications/ColossalMoE/tests/test_moe_checkpoint.py +++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py @@ -1,185 +1,144 @@ -import os -import shutil +from copy import deepcopy import pytest import torch import torch.distributed as dist -from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO -from colossal_moe.models.mixtral_layer import replace_moe_layer +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy -from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM +from torch.optim import Adam +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM import colossalai from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.moe.manager import MOE_MANAGER -from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn -from colossalai.utils import get_current_device +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 -def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20): - input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device()) - attention_mask = torch.ones_like(input_ids) +def check_model_equal(model1, model2): + assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert torch.equal(p1.half(), p2.half()) + + +def get_optimizer_snapshot(optim): + state = {id(k): deepcopy(v) for k, v in optim.state.items()} + param_groups = [] + for group in optim.param_groups: + params = [id(p) for p in group["params"]] + new_group = {"params": params} + for k, v in group.items(): + if k != "params": + new_group[k] = v + param_groups.append(new_group) return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": input_ids, + "state": state, + "param_groups": param_groups, } -def run_fwd_bwd( - model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None -): - model.train() - if pipeline: - train_dataloader_iter = DummyDataloader(data_gen_fn, length=1) - is_pp_last_stage = booster.plugin.stage_manager.is_last_stage() - y = booster.execute_pipeline( - train_dataloader_iter, - model, - lambda x, y: x.loss, - optimizer, - return_loss=True, - return_outputs=True, - ) - # Backward and optimize - if is_pp_last_stage: - loss = y["loss"] - else: - if criterion: - y = model(data).logits - loss = criterion(y) - else: - loss = model(data, label) - loss = loss.float() - - if optimizer is not None: - optimizer.backward(loss) - else: - loss.backward() - return y +def check_optimizer_snapshot_equal(snapshot1, snapshot2): + # check param_groups + assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"]) + for group1, group2 
in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
+        assert set(group1.keys()) == set(group2.keys())
+        for k in group1.keys():
+            assert group1[k] == group2[k]
+    # check state
+    assert set(snapshot1["state"].keys()) == set(
+        snapshot2["state"].keys()
+    ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
+    for pid in snapshot1["state"].keys():
+        state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
+        assert set(state1.keys()) == set(state2.keys())
+        for k in state1.keys():
+            if isinstance(state1[k], torch.Tensor):
+                assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
+            else:
+                assert state1[k] == state2[k]


-def get_config():
+def check_mixtral_moe_layer():
+    torch.cuda.set_device(dist.get_rank())
     config = MixtralConfig(
-        vocab_size=300,
-        hidden_size=32,
-        intermediate_size=16,
-        num_hidden_layers=2,
-        dropout_rate=0.0,
+        hidden_size=hidden_size,
+        intermediate_size=hidden_size * 2,
+        num_local_experts=n_experts,
+        num_experts_per_tok=top_k,
+        num_attention_heads=2,
+        num_key_value_heads=2,
     )
-    return config
-
-
-def get_model(parallel):
-    config = get_config()
-    model = MixtralForCausalLM(config).to(torch.bfloat16)
-    replace_moe_layer(model)
-    optim = torch.optim.Adam(model.parameters())
-    args = dict(
-        precision="bf16",
+    torch.manual_seed(0)
+    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
+    orig_model = MixtralForCausalLM(config).cuda()
+    model = deepcopy(orig_model)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    plugin = MoeHybridParallelPlugin(
         tp_size=1,
-        zero_stage=1,
+        pp_size=2,
+        ep_size=2,
         custom_policy=MixtralForCausalLMPolicy(),
-        checkpoint_io=MixtralMoECheckpointIO,
+        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
+        microbatch_size=1,
+        zero_stage=1,
     )
-    if parallel == "ep":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=1,
-            **args,
-        )
-    elif parallel == "hybrid":
-        plugin = MoeHybridParallelPlugin(
-            pp_size=2,
-            microbatch_size=1,
-            **args,
-        )
     booster = Booster(plugin=plugin)
-    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
-    return model, booster, optim
-
-
-def _test_moe_checkpoint(parallel):
-    if dist.get_rank() == 0:
-        if os.path.exists("./tmp_ckpt1"):
-            shutil.rmtree("./tmp_ckpt1")
-        if os.path.exists("./tmp_ckpt2"):
-            shutil.rmtree("./tmp_ckpt2")
-    dist.barrier()
-
-    if parallel == None:
-        MOE_MANAGER.setup(
-            parallel=None,
-        )
-    elif parallel == "ep":
-        MOE_MANAGER.setup(
-            parallel="EP",
-        )
-    elif parallel == "hybrid":
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=1,
-            fixed_ep_size=2,
-            fixed_pp_size=2,
-        )
-    model1, booster1, optim1 = get_model(parallel)
-    model2, booster2, optim2 = get_model(parallel)
-    # param ckpt
-    # check not equal
-    try:
-        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-        raise AssertionError("state_dict should not be equal")
-    except:
-        pass
-    # shard
-    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
-    booster2.load_model(model2, "./tmp_ckpt1")
-    # check
-    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
-
-    # optim ckpt
-    criterion = lambda x: x.mean()
-    data = torch.randint(0, 4, (2, 4)).cuda()
-    label = torch.randint(0, 4, (2,)).cuda()
-    if parallel == "hybrid":
-        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
-    else:
-        kwargs = {}
-    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
-    optim1.step()
-    optim1.zero_grad()
-    # shard
-    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
-    dist.barrier()
-    booster2.load_optimizer(optim2, "./tmp_ckpt2")
-    # check
-    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
-
-    if dist.get_rank() == 0:
-        shutil.rmtree("./tmp_ckpt1")
-        shutil.rmtree("./tmp_ckpt2")
-
-
-def _run_dist(rank, world_size, port, parallel):
-    colossalai.launch(
-        config=dict(),
-        rank=rank,
-        world_size=world_size,
-        host="localhost",
-        port=port,
-        backend="nccl",
+    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
+    # initialize grads
+    data_iter = iter(
+        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
     )
-    _test_moe_checkpoint(parallel)
+    booster.execute_pipeline(
+        data_iter,
+        model,
+        lambda outputs, inputs: outputs.loss,
+        optimizer,
+    )
+
+    # check save model
+    booster.save_model(model, "mixtral_model", shard=True)
+    dist.barrier()
+    if dist.get_rank() == 0:
+        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
+        check_model_equal(orig_model, saved_model)
+        saved_model.save_pretrained("mixtral_hf_model")
+    dist.barrier()
+
+    # check load model
+    new_model = MixtralForCausalLM(config).cuda()
+    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
+    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
+    booster.load_model(new_model, "mixtral_hf_model")
+    check_model_equal(model, new_model)
+
+    # check save optimizer
+    optimizer.step()
+    snapshot = get_optimizer_snapshot(optimizer.unwrap())
+    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
+    dist.barrier()
+    # reset optimizer state
+    for state in optimizer.unwrap().state.values():
+        for v in state.values():
+            if isinstance(v, torch.Tensor):
+                v.zero_()
+    booster.load_optimizer(optimizer, "mixtral_optim")
+    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
+    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
+
+
+def run_dist(rank: int, world_size: int, port: int):
+    colossalai.launch({}, rank, world_size, "localhost", port)
+    check_mixtral_moe_layer()


-@pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
-@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
-@rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size, parallel):
-    spawn(_run_dist, world_size, parallel=parallel)
+def test_mixtral_moe_layer(world_size: int):
+    spawn(run_dist, world_size)


 if __name__ == "__main__":
-    test_moe_checkpoint(world_size=4, parallel="hybrid")
+    test_mixtral_moe_layer(4)
diff --git a/applications/ColossalMoE/tests/test_moe_layer.py b/applications/ColossalMoE/tests/test_moe_layer.py
deleted file mode 100644
index 8b090c427..000000000
--- a/applications/ColossalMoE/tests/test_moe_layer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import copy
-
-import torch
-from colossal_moe.models.mixtral_layer import MixtralSparseMLP
-from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-
-class Config:
-    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_local_experts = num_local_experts
-        self.num_experts_per_tok = num_experts_per_tok
-        self.hidden_act = hidden_act
-
-
-def test_moe_layer():
-    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
-    mistral_moe = MixtralSparseMoeBlock(config).cuda()
-    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()
-
-    data = torch.randn(2, 8, 4).cuda()
-    mistral_output = mistral_moe(data)[0]
-    colossal_output = colossal_moe(data)[0]
-    assert torch.allclose(
-        mistral_output, colossal_output
-    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
-
-
-if __name__ == "__main__":
-    test_moe_layer()
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index 1d0441a5a..c567038ec 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -2,22 +2,18 @@ import argparse

 import torch
 import torch.distributed as dist
-from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
-from colossal_moe.models.mixtral_layer import replace_moe_layer
+from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
+from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
-from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+from transformers.models.mixtral import MixtralForCausalLM

 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.moe import MOE_MANAGER, apply_load_balance
-from colossalai.moe.layers import apply_load_balance
-from colossalai.moe.manager import MOE_MANAGER
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device
@@ -153,45 +149,27 @@ def main():
     coordinator = DistCoordinator()

     # Set plugin
-    booster_kwargs = {}
-    hybrid_dict = {
-        "tp_size": 1,
-        "custom_policy": MixtralForCausalLMPolicy(),
-        "enable_fused_normalization": args.use_layernorm_kernel,
-        "enable_jit_fused": args.use_kernel,
-        "precision": args.precision,
-        "zero_stage": args.zero_stage,
-        "checkpoint_io": MixtralMoECheckpointIO,
-    }
-    mgr_dict = {}
     if args.plugin == "hybrid":
         plugin = MoeHybridParallelPlugin(
+            tp_size=1,
             pp_size=args.pp_size,
+            ep_size=args.ep_size,
             microbatch_size=args.microbatch_size,
-            **hybrid_dict,
-        )
-        MOE_MANAGER.setup(
-            parallel="EP",
-            mode="fixed",
-            fixed_dp_size=args.dp_size,
-            fixed_ep_size=args.ep_size,
-            fixed_pp_size=args.pp_size,
-            **mgr_dict,
+            custom_policy=MixtralForCausalLMPolicy(),
+            enable_fused_normalization=args.use_layernorm_kernel,
+            enable_jit_fused=args.use_kernel,
+            precision=args.precision,
+            zero_stage=args.zero_stage,
+            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
         )
+
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
     coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

     # Build Mixtral model
-    config = MixtralConfig.from_pretrained(args.model_name)
-    config.use_cache = False
-    config.num_local_experts = 1
-    model = MixtralForCausalLM(config)
-    model.num_experts = 8
-    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
-    model = model.to(get_current_device())
-    replace_moe_layer(model, enable_kernel=args.use_kernel)
-    coordinator.print_on_master(f"Finish init model with config:\n{config}")
+    model = MixtralForCausalLM.from_pretrained(args.model_name)
+    coordinator.print_on_master(f"Finish init model")

     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()
@@ -224,7 +202,7 @@ def main():
     )

     # Set booster
-    booster = Booster(plugin=plugin, **booster_kwargs)
+    booster = Booster(plugin=plugin)
     model, optimizer, _, dataloader, lr_scheduler = booster.boost(
         model=model,
         optimizer=optimizer,
@@ -236,10 +214,7 @@ def main():
     coordinator.print_on_master(f"Finish init booster")

     # Load ckpt
-    if args.load_checkpoint is None:
-        load_model(args.model_name, model, booster, optimizer)
-        coordinator.print_on_master(f"Finish load checkpoint")
-    else:
+    if args.load_checkpoint is not None:
         load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
         coordinator.print_on_master(f"Finish load optimizer")

@@ -286,13 +261,13 @@ def main():
                 optimizer.zero_grad()

                 # Apply load balance
-                if (
-                    args.load_balance
-                    and args.load_balance_interval > 0
-                    and (step + 1) % args.load_balance_interval == 0
-                ):
-                    coordinator.print_on_master(f"Apply load balance")
-                    apply_load_balance(model, optimizer)
+                # if (
+                #     args.load_balance
+                #     and args.load_balance_interval > 0
+                #     and (step + 1) % args.load_balance_interval == 0
+                # ):
+                #     coordinator.print_on_master(f"Apply load balance")
+                #     apply_load_balance(model, optimizer)
                 # save ckeckpoint
                 if (step + 1) % args.save_interval == 0:
                     coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 07cbc14a7..45e5a23c1 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
 )
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpintIO
+from colossalai.moe import MOE_MANAGER, MoECheckpintIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self,
         tp_size: int,
         pp_size: int,
+        ep_size: int,
         extra_dp_size: int = 1,
         precision: str = "fp16",
         zero_stage: int = 0,
@@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):

         if enable_sequence_parallelism:
             assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
-
+        assert (
+            dist.get_world_size() % (tp_size * pp_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+        assert (
+            dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
+        ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
+        self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
+        MOE_MANAGER.setup(
+            parallel="EP",
+            mode="fixed",
+            fixed_dp_size=self.real_dp_size,
+            fixed_ep_size=ep_size,
+            fixed_pp_size=pp_size,
+            use_ep_inside=use_ep_inside,
+        )
         self.tp_size = tp_size
         self.pp_size = pp_size
         self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+        self.ep_size = ep_size
+        self.moe_info = MOE_MANAGER.get_info(0)[1]
         self.precision = precision
         self.zero_stage = zero_stage
         self.cpu_offload = cpu_offload
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index 780117598..712324215 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

 from colossalai.interface import ModelWrapper

-from .utils import has_index_file
+from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

 __all__ = ["CheckpointIO"]

@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
-            self.load_unsharded_model(model, checkpoint, strict)
+            path = Path(checkpoint, SAFE_WEIGHTS_NAME)
+            if path.is_file():
+                self.load_unsharded_model(model, str(path), strict)
+            else:
+                path = Path(checkpoint, WEIGHTS_NAME)
+                if path.is_file():
+                    self.load_unsharded_model(model, str(path), strict)
+                else:
+                    self.load_unsharded_model(model, checkpoint, strict)

         return origin_model

diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 34342436f..01c837ee3 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
         if ctx.ep_size != 1:
             grad = grad / ctx.ep_size
         return grad, None
+
+
+def _all_to_all(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    async_op: bool = False,
+):
+    """
+    Returns:
+        outputs: Tensor
+        handle: Optional[Work], if overlap is True
+    """
+    outputs_shape = list(inputs.shape)
+    if output_split_sizes is not None:
+        outputs_shape[0] = sum(output_split_sizes)
+    outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
+    inputs = inputs.contiguous()
+    outputs = outputs.contiguous()
+    handle = dist.all_to_all_single(
+        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+    )
+    return outputs, handle
+
+
+class AllToAllUneven(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        inputs,
+        input_split_sizes=None,
+        output_split_sizes=None,
+        group=None,
+        overlap: bool = False,
+    ):
+        """
+        Returns:
+            outputs: Tensor
+            handle: Optional[Work], if overlap is True
+        """
+        ctx.input_split_sizes = input_split_sizes
+        ctx.output_split_sizes = output_split_sizes
+        ctx.group = group
+        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs):
+        return (
+            _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def all_to_all_uneven(
+    inputs: torch.Tensor,
+    input_split_sizes: Optional[List[int]] = None,
+    output_split_sizes: Optional[List[int]] = None,
+    group=None,
+    overlap: bool = False,
+):
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py
index ba6c77056..5ac3c2b3a 100644
--- a/colossalai/tensor/moe_tensor/moe_info.py
+++ b/colossalai/tensor/moe_tensor/moe_info.py
@@ -26,3 +26,5 @@ class MoeParallelInfo:
         self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
         self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
         self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
+        self.ep_rank = self.pg.coordinate(self.ep_axis)
+        self.dp_rank = self.pg.coordinate(self.dp_axis)
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 47bc7603a..511eb26e8 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

-    def sync_moe_master_param(self):
-        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
-            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
-
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
@@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
                 else:
                     master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.copy_(working_moe_param)

     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
         return self._param_store.working_to_master_param

     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
-        return self._param_store.master_to_working_param
+        return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
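
Note: a minimal usage sketch (not part of the patch) for the all_to_all_uneven helper added to colossalai/moe/_operation.py above. The dispatch_tokens wrapper, the split sizes, and the ep_group argument below are illustrative assumptions; only the call signature (inputs, input_split_sizes, output_split_sizes, group, overlap) comes from the patch itself.

from typing import List

import torch

from colossalai.moe._operation import all_to_all_uneven


def dispatch_tokens(
    hidden_states: torch.Tensor,
    input_split_sizes: List[int],
    output_split_sizes: List[int],
    ep_group,
) -> torch.Tensor:
    # hidden_states: [num_local_tokens, hidden_size]. The split sizes are the per-rank
    # token counts produced by a router, so ranks may exchange different numbers of tokens.
    # With overlap=False the second return value (the async work handle) is None.
    outputs, _ = all_to_all_uneven(hidden_states, input_split_sizes, output_split_sizes, ep_group)
    return outputs  # [sum(output_split_sizes), hidden_size]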