diff --git a/applications/ColossalMoE/README.md b/applications/ColossalMoE/README.md new file mode 100644 index 000000000..be50a8f9f Binary files /dev/null and b/applications/ColossalMoE/README.md differ diff --git a/applications/ColossalMoE/colossal_moe/__init__.py b/applications/ColossalMoE/colossal_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/colossal_moe/models/__init__.py b/applications/ColossalMoE/colossal_moe/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py new file mode 100644 index 000000000..d08dfd5f8 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py @@ -0,0 +1,629 @@ +import copy +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Dict, Iterator, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.checkpoint_io import CheckpointIndexFile +from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO +from colossalai.checkpoint_io.index_file import CheckpointIndexFile +from colossalai.checkpoint_io.utils import ( + StateDictSharder, + gather_distributed_param, + get_model_base_filenames, + get_optimizer_base_filenames, + load_shard_state_dict, + load_states_into_optimizer, + save_config_file, + save_param_groups, + save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, +) +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.moe import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = "_extra_state" + + +class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO): + def __init__( + self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True, + ) -> None: + super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose) + moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size] + self.ep_group = moe_info.ep_group + self.ep_size = moe_info.ep_size + self.ep_rank = moe_info.ep_rank + self.real_dp_rank = moe_info.dp_rank + + @staticmethod + def _model_sharder( + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + param_name_pattern: Optional[str] = None, + ) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + if param_name_pattern is not None and param_name_pattern not in name: + continue + # Gather tensor pieces when using tensor parallel. + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if ( + getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) + is not torch.nn.Module.get_extra_state + ): + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model = model.unwrap() + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + if self.real_dp_rank != 0: + dist.barrier() + return + + # ep_rank 0 saves all the parameters and buffers. + # other ep_ranks save only experts + ep_param_pattern = "experts." if self.ep_rank != 0 else None + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_rank == 0 are responsible for model saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern + ) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + dist.barrier() + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + weights_name = weights_name.replace( + ".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors" + ) + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + dist.barrier() + return + + dist.barrier() + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + save_config_file(model, checkpoint) + rmtree(tmp_index_file_folder) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + @staticmethod + def gather_from_sharded_optimizer_state( + state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + use_zero: bool, + inplace: bool, + is_moe_param: bool, + device: torch.device = torch.device("cpu"), + ) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu'). + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # First gather Zero shards. + if use_zero and not is_moe_param: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().to(device) + + return state_ + + @staticmethod + def _optimizer_sharder( + optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + size_per_shard: int = 1024, + only_moe_param: bool = False, + ): + # An internel method that breaks state_dict of optimizer into shards within limited size. + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + master_to_working_map = optimizer.get_master_to_working_map() + + for param, state in optimizer.optim.state.items(): + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info["param2id"][id(working_param)] + original_shape = param_info["param2shape"][id(working_param)] + state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state( + state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False, + is_moe_param=is_moe_tensor(working_param), + ) + + if only_moe_param and not is_moe_tensor(working_param): + continue + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if not self.use_zero and self.real_dp_rank != 0: + dist.barrier() + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + size_per_shard=size_per_shard, + only_moe_param=self.ep_rank != 0, + ) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = self.real_dp_rank == 0 and self.tp_rank == 0 + + if self.pp_size == 1 and self.ep_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + dist.barrier() + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) + + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + dist.barrier() + return + + dist.barrier() + + # The global master rank integrates the index files and clean the folder. + if self.coordinator.is_master(): + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = [ + {**group, "params": group_info["params"]} + for group, group_info in zip(optimizer.param_groups, optimizer.param_info["param_groups"]) + ] + save_param_groups({"param_groups": param_groups}, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}." + ) + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + + def _get_param_id_from_optimizer_param( + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + ): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info["param2id"][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. + id_map = {} + master_to_working_map = optimizer.get_master_to_working_map() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + id_map[param_id] = param + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + # ep param groups + if len(optimizer.optim.param_groups) == len(saved_groups) + 1: + new_pg = copy.deepcopy(saved_pg) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({"param_groups": updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg["params"]: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state( + state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True, + is_moe_param=is_moe_tensor(working_param), + ) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose and self.coordinator.is_master(): + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + + def shard_from_complete_optimizer_state( + self, + state: OrderedDict, + current_shape: torch.Size, + original_shape: torch.Size, + device: torch.device, + inplace: bool, + is_moe_param: bool, + ) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != "step": + # Shard state along tensor parallel group. + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero and not is_moe_param: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False): + raise NotImplementedError diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py new file mode 100644 index 000000000..a2b78a2bd --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_layer.py @@ -0,0 +1,92 @@ +import torch +import torch.distributed as dist +import torch.nn.functional as F +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from colossalai.lazy import LazyInitContext +from colossalai.moe import MOE_MANAGER +from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven +from colossalai.shardformer.shard.utils import set_tensors_to_none +from colossalai.tensor.moe_tensor.api import set_moe_tensor_info + + +class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock): + def __init__(self, config): + super().__init__(config) + self.setup_ep() + + def setup_ep(self): + _, moe_info = MOE_MANAGER.get_info(self.num_experts) + ep_group = moe_info.ep_group + self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1 + self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0 + assert self.num_experts % self.ep_size == 0 + self.ep_group = ep_group + self.num_experts_per_ep = self.num_experts // self.ep_size + self.expert_start_idx = self.ep_rank * self.num_experts_per_ep + held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep] + set_tensors_to_none(self.experts, exclude=set(held_experts)) + for p in self.experts.parameters(): + set_moe_tensor_info(p, moe_info) + + @staticmethod + def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock": + LazyInitContext.materialize(module) + module.__class__ = EPMixtralSparseMoeBlock + module.setup_ep() + return module + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + selected_experts = selected_experts.t().reshape(-1) + selected_experts_idx = selected_experts.argsort() + dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx] + input_split_sizes = selected_experts.bincount(minlength=self.num_experts) + output_split_sizes = torch.zeros_like(input_split_sizes) + dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group) + + input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist() + output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group) + # compute expert output + output_states = MoeInGradScaler.apply(output_states, self.ep_size) + if output_states.size(0) > 0: + if self.num_experts_per_ep == 1: + # no need to split + expert = self.experts[self.expert_start_idx] + output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states) + output_states = expert.w2(output_states) + else: + output_states_splits = output_states.split(output_split_sizes.tolist()) + output_states_list = [] + for i, split_states in enumerate(output_states_splits): + if split_states.size(0) == 0: + continue + expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep] + split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states) + split_states = expert.w2(split_states) + output_states_list.append(split_states) + output_states = torch.cat(output_states_list) + output_states = MoeOutGradScaler.apply(output_states, self.ep_size) + dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group) + recover_experts_idx = torch.empty_like(selected_experts_idx) + recover_experts_idx[selected_experts_idx] = torch.arange( + selected_experts_idx.size(0), device=selected_experts_idx.device + ) + dispatch_states = dispatch_states[recover_experts_idx] + k_hidden_states = dispatch_states.chunk(self.top_k) + output_states = k_hidden_states[0] * routing_weights[:, 0, None] + for i in range(1, self.top_k): + output_states += k_hidden_states[i] * routing_weights[:, i, None] + output_states = output_states.reshape(batch_size, sequence_length, hidden_dim) + return output_states, router_logits diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py new file mode 100644 index 000000000..218b05b27 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -0,0 +1,557 @@ +from functools import partial +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import CrossEntropyLoss, Module +from transformers.models.mixtral.modeling_mixtral import ( + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MoeCausalLMOutputWithPast, + _prepare_4d_causal_attention_mask, + load_balancing_loss_func, +) +from transformers.utils import logging + +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.shard import ShardConfig + +from .mixtral_layer import EPMixtralSparseMoeBlock + +__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"] + + +class MixtralPolicy(Policy): + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag." + ) + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.") + + # expert parallel + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="block_sparse_moe", + target_module=EPMixtralSparseMoeBlock, + ) + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + ], + policy=policy, + target_key=MixtralDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=MixtralModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in mixtral.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=model_cls + ) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "MixtralModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class MixtralModelPolicy(MixtralPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralModel, + new_forward=MixtralPipelineForwards.mixtral_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class MixtralForCausalLMPolicy(MixtralPolicy): + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + MixtralForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ] + ) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=MixtralForCausalLM, + new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): + # tie weights + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] + return [] + + +class MixtralPipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. + """ + + @staticmethod + def mixtral_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + output_router_logits, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + output_router_logits, + use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + if output_router_logits and past_router_logits is not None: + all_router_logits = past_router_logits + all_router_logits + if stage_manager.is_last_stage(): + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + "past_router_logits": all_router_logits, + } + + @staticmethod + def mixtral_for_causal_lm_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + past_router_logits: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + shard_config: ShardConfig = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. + if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = MixtralPipelineForwards.mixtral_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_logits=past_router_logits, + ) + past_key_values = None + + if stage_manager.is_last_stage(): + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=None, + hidden_states=outputs[0], + attentions=None, + router_logits=outputs[-1], + ) + else: + out = {} + hidden_states = outputs.get("hidden_states") + out["hidden_states"] = hidden_states + if output_router_logits: + out["past_router_logits"] = outputs["past_router_logits"] + return out diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py new file mode 100644 index 000000000..a2a0a7e78 --- /dev/null +++ b/applications/ColossalMoE/colossal_moe/utils.py @@ -0,0 +1,84 @@ +import json +import os +from typing import Any, Dict, Tuple, Union + +import torch +from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.optimizer import Optimizer + +from colossalai.booster import Booster +from colossalai.cluster import DistCoordinator + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]: + """ + Load file in JSON format + """ + with open(file=file_path, mode="r", encoding="utf-8") as fp: + return json.load(fp) + + +def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None: + """ + Save as JSON format + """ + with open(file=file_path, mode="w", encoding="utf-8") as fp: + json.dump(data, fp=fp, ensure_ascii=False, indent=4) + + +def save_checkpoint( + save_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, + epoch: int, + step: int, + batch_size: int, + coordinator: DistCoordinator, +) -> None: + """ + Save model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") + os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) + + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) + booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) + running_states = { + "epoch": epoch, + "step": step, + "sample_start_index": step * batch_size, + } + if coordinator.is_master(): + save_json(running_states, os.path.join(save_dir, "running_states.json")) + + +def load_checkpoint( + load_dir: Union[str, os.PathLike], + booster: Booster, + model: torch.nn.Module, + optimizer: Optimizer, + lr_scheduler: _LRScheduler, +) -> Tuple[int, int, int]: + """ + Load model checkpoint, optimizer, LR scheduler and intermedidate running states. + """ + + # Update booster params states. + booster.load_model(model, os.path.join(load_dir, "modeling")) + booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer")) + booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler")) + + running_states = load_json(file_path=os.path.join(load_dir, "running_states.json")) + return ( + running_states["epoch"], + running_states["step"], + running_states["sample_start_index"], + ) diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py new file mode 100644 index 000000000..46ff70ff3 --- /dev/null +++ b/applications/ColossalMoE/infer.py @@ -0,0 +1,111 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--plugin", + type=str, + default="ep", + choices=["ep"], + help="Parallel methos.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + config = MixtralConfig.from_pretrained(args.model_name) + ep_size = min(dist.get_world_size(), config.num_local_experts) + # Set plugin + if args.plugin == "ep": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=1, + ep_size=ep_size, + zero_stage=1, + precision=args.precision, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build mixtral model + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish load model") + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + # Set booster + booster = Booster(plugin=plugin) + model, _, _, _, _ = booster.boost(model=model) + coordinator.print_on_master(f"Finish init booster") + + model.eval() + + if coordinator.rank == 0: + text = ["Hello my name is"] + else: + text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"] + tokenizer.pad_token = tokenizer.unk_token + inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device()) + + with torch.no_grad(): + outputs = model.module.generate(**inputs, max_new_tokens=20) + outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) + print(f"[{coordinator.rank}] {outputs}") + + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/infer.sh b/applications/ColossalMoE/infer.sh new file mode 100644 index 000000000..0487fe9c1 --- /dev/null +++ b/applications/ColossalMoE/infer.sh @@ -0,0 +1,7 @@ +NUM_GPU=2 +MODEL="mistralai/Mixtral-8x7B-v0.1" + +# ep +torchrun --standalone --nproc_per_node $NUM_GPU infer.py \ + --model_name $MODEL \ + --plugin "ep" \ diff --git a/applications/ColossalMoE/requirements.txt b/applications/ColossalMoE/requirements.txt new file mode 100644 index 000000000..9a5738c41 --- /dev/null +++ b/applications/ColossalMoE/requirements.txt @@ -0,0 +1,5 @@ +colossalai >= 0.3.3 +torch >= 1.8.1 +transformers == 4.36.0 +sentencepiece +datasets diff --git a/applications/ColossalMoE/setup.py b/applications/ColossalMoE/setup.py new file mode 100644 index 000000000..275f59e10 --- /dev/null +++ b/applications/ColossalMoE/setup.py @@ -0,0 +1,43 @@ +from setuptools import find_packages, setup + + +def fetch_requirements(path): + with open(path, "r") as fd: + return [r.strip() for r in fd.readlines()] + + +def fetch_readme(): + with open("README.md", encoding="utf-8") as f: + return f.read() + + +def fetch_version(): + with open("version.txt", "r") as f: + return f.read().strip() + + +setup( + name="colossal_moe", + version=fetch_version(), + packages=find_packages( + exclude=( + "tests", + "benchmarks", + "*.egg-info", + ) + ), + description="Colossal-AI MoE", + long_description=fetch_readme(), + long_description_content_type="text/markdown", + license="Apache Software License 2.0", + url="https://github.com/hpcaitech", + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.6", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Environment :: GPU :: NVIDIA CUDA", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: System :: Distributed Computing", + ], +) diff --git a/applications/ColossalMoE/tests/__init__.py b/applications/ColossalMoE/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/ColossalMoE/tests/test_mixtral_layer.py b/applications/ColossalMoE/tests/test_mixtral_layer.py new file mode 100644 index 000000000..57589ab20 --- /dev/null +++ b/applications/ColossalMoE/tests/test_mixtral_layer.py @@ -0,0 +1,63 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock +from torch.testing import assert_close +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +import colossalai +from colossalai.moe import MOE_MANAGER +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + MOE_MANAGER.setup( + parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1 + ) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + ) + torch.manual_seed(0) + orig_model = MixtralSparseMoeBlock(config).cuda() + x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda() + orig_output, orig_logits = orig_model(x) + model = deepcopy(orig_model) + model = EPMixtralSparseMoeBlock.from_native_module(model) + ep_output, ep_logits = model(x) + assert_close(orig_logits, ep_logits) + assert_close(orig_output, ep_output) + orig_loss = orig_output.mean() + orig_loss.backward() + ep_loss = ep_output.mean() + ep_loss.backward() + assert_close(orig_loss, ep_loss) + name_to_p = {n: p for n, p in orig_model.named_parameters()} + for n, ep_p in model.named_parameters(): + p = name_to_p[n] + if ep_p.grad is not None: + assert_close(p.grad, ep_p.grad) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(2) diff --git a/applications/ColossalMoE/tests/test_moe_checkpoint.py b/applications/ColossalMoE/tests/test_moe_checkpoint.py new file mode 100644 index 000000000..822e7410f --- /dev/null +++ b/applications/ColossalMoE/tests/test_moe_checkpoint.py @@ -0,0 +1,146 @@ +from copy import deepcopy + +import pytest +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from torch.optim import Adam +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.testing.utils import spawn + +tokens, n_experts = 7, 4 +hidden_size = 8 +top_k = 2 + + +def check_model_equal(model1, model2): + assert set(model1.state_dict().keys()) == set(model2.state_dict().keys()) + for p1, p2 in zip(model1.parameters(), model2.parameters()): + assert torch.equal(p1.half(), p2.half()) + + +def get_optimizer_snapshot(optim): + state = {id(k): deepcopy(v) for k, v in optim.state.items()} + param_groups = [] + for group in optim.param_groups: + params = [id(p) for p in group["params"]] + new_group = {"params": params} + for k, v in group.items(): + if k != "params": + new_group[k] = v + param_groups.append(new_group) + return { + "state": state, + "param_groups": param_groups, + } + + +def check_optimizer_snapshot_equal(snapshot1, snapshot2): + # check param_groups + assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"]) + for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]): + assert set(group1.keys()) == set(group2.keys()) + for k in group1.keys(): + assert group1[k] == group2[k] + # check state + assert set(snapshot1["state"].keys()) == set( + snapshot2["state"].keys() + ), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}" + for pid in snapshot1["state"].keys(): + state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid] + assert set(state1.keys()) == set(state2.keys()) + for k in state1.keys(): + if isinstance(state1[k], torch.Tensor): + assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}" + else: + assert state1[k] == state2[k] + + +def check_mixtral_moe_layer(): + torch.cuda.set_device(dist.get_rank()) + config = MixtralConfig( + hidden_size=hidden_size, + intermediate_size=hidden_size * 2, + num_local_experts=n_experts, + num_experts_per_tok=top_k, + num_attention_heads=2, + num_key_value_heads=2, + ) + torch.manual_seed(0) + input_ids = torch.randint(0, 100, (2, tokens)).cuda() + orig_model = MixtralForCausalLM(config).cuda() + model = deepcopy(orig_model) + optimizer = Adam(model.parameters(), lr=1e-3) + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=2, + ep_size=2, + custom_policy=MixtralForCausalLMPolicy(), + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + microbatch_size=1, + zero_stage=1, + ) + booster = Booster(plugin=plugin) + model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer) + # initialize grads + data_iter = iter( + [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}] + ) + booster.execute_pipeline( + data_iter, + model, + lambda outputs, inputs: outputs.loss, + optimizer, + ) + + # check save model + booster.save_model(model, "mixtral_model", shard=True) + dist.barrier() + if dist.get_rank() == 0: + saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda() + check_model_equal(orig_model, saved_model) + saved_model.save_pretrained("mixtral_hf_model") + dist.barrier() + + # check load model + new_model = MixtralForCausalLM(config).cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer) + booster.load_model(new_model, "mixtral_hf_model") + check_model_equal(model, new_model) + + # check save optimizer + optimizer.step() + for group in optimizer.param_groups: + group["lr"] = 0.1 + snapshot = get_optimizer_snapshot(optimizer.unwrap()) + booster.save_optimizer(optimizer, "mixtral_optim", shard=True) + dist.barrier() + # reset optimizer state + for state in optimizer.unwrap().state.values(): + for v in state.values(): + if isinstance(v, torch.Tensor): + v.zero_() + booster.load_optimizer(optimizer, "mixtral_optim") + loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap()) + check_optimizer_snapshot_equal(snapshot, loaded_snapshot) + + +def run_dist(rank: int, world_size: int, port: int): + colossalai.launch({}, rank, world_size, "localhost", port) + check_mixtral_moe_layer() + + +@pytest.mark.parametrize("world_size", [4]) +def test_mixtral_moe_layer(world_size: int): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_mixtral_moe_layer(4) diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py new file mode 100644 index 000000000..c567038ec --- /dev/null +++ b/applications/ColossalMoE/train.py @@ -0,0 +1,295 @@ +import argparse + +import torch +import torch.distributed as dist +from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO +from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy +from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import AutoTokenizer +from transformers.models.mixtral import MixtralForCausalLM + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +@torch.no_grad() +def get_global_loss(loss, booster): + global_loss = loss.clone().detach() + dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group) + global_loss.div_(booster.plugin.dp_size) + return global_loss + + +class RandomDataset(Dataset): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="mistralai/Mixtral-8x7B-v0.1", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + choices=["hybrid"], + help="Parallel methods.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./outputs", + help="The path of your saved model after finetuning.", + ) + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--save_interval", + type=int, + default=1000, + help=" The interval (steps) of saving checkpoints.", + ) + parser.add_argument( + "--precision", + type=str, + default="bf16", + choices=["fp32", "bf16", "fp16"], + help="The mixed precision training.", + ) + parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + + # optim + parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + + # lr scheduler + parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + + # zero stage for all plugins + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.") + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin") + parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin") + parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin") + + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.", + ) + parser.add_argument( + "--use_layernorm_kernel", + action="store_true", + help="Use layernorm kernel. Need to install apex. Raise error if not installed.", + ) + + # load balance + parser.add_argument( + "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable." + ) + parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.") + # communicate overlap + parser.add_argument( + "--comm_overlap", + action="store_true", + help="Use communication overlap for MoE. Recommended to enable for muiti-node training.", + ) + # hierarchical all-to-all + parser.add_argument( + "--hierarchical_alltoall", + action="store_true", + help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set plugin + if args.plugin == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + ep_size=args.ep_size, + microbatch_size=args.microbatch_size, + custom_policy=MixtralForCausalLMPolicy(), + enable_fused_normalization=args.use_layernorm_kernel, + enable_jit_fused=args.use_kernel, + precision=args.precision, + zero_stage=args.zero_stage, + checkpoint_io=MixtralMoEHybridParallelCheckpointIO, + ) + + else: + raise ValueError(f"Invalid plugin {args.plugin}") + coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}") + + # Build Mixtral model + model = MixtralForCausalLM.from_pretrained(args.model_name) + coordinator.print_on_master(f"Finish init model") + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + dataset = RandomDataset(num_samples=100, tokenizer=tokenizer) + collate_fn = None + dataloader = plugin.prepare_dataloader( + dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn + ) + + # Set optimizer + optimizer = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + # Set lr scheduler + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * len(dataloader), + warmup_steps=args.warmup_steps + if args.warmup_steps is not None + else int(args.num_epochs * len(dataloader) * 0.025), + eta_min=0.1 * args.lr, + ) + + # Set booster + booster = Booster(plugin=plugin) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1 + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + coordinator.print_on_master(f"Finish init booster") + + # Load ckpt + if args.load_checkpoint is not None: + load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler) + coordinator.print_on_master(f"Finish load optimizer") + + # Start finetuning + coordinator.print_on_master(f"Start finetuning") + for epoch in range(args.num_epoch): + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage, + ) as pbar: + for step in pbar: + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + global_loss = get_global_loss(loss, booster) + if coordinator._local_rank == "0": + pbar.set_postfix({"Loss": global_loss.item()}) + else: + # Forward pass + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Apply load balance + # if ( + # args.load_balance + # and args.load_balance_interval > 0 + # and (step + 1) % args.load_balance_interval == 0 + # ): + # coordinator.print_on_master(f"Apply load balance") + # apply_load_balance(model, optimizer) + # save ckeckpoint + if (step + 1) % args.save_interval == 0: + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + save_checkpoint( + args.output_path, + booster, + model, + optimizer, + lr_scheduler, + epoch, + step, + args.batch_size, + coordinator, + ) + + # save checkpoint at the end of each epochs + booster.save_model(model, args.output_path, shard=True, size_per_shard=5120) + coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}") + + # Finish training + coordinator.print_on_master(f"Finish training") + + +if __name__ == "__main__": + main() diff --git a/applications/ColossalMoE/train.sh b/applications/ColossalMoE/train.sh new file mode 100644 index 000000000..bee7f5c8f --- /dev/null +++ b/applications/ColossalMoE/train.sh @@ -0,0 +1,19 @@ +NUM_GPU=8 +MODEL="mistralai/Mixtral-8x7B-v0.1" +SEQ_LENGTH=2048 +BATCH_SIZE=1 +LR=0.00001 + +# hybrid +# torchrun --standalone --nproc_per_node $NUM_GPU \ +colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \ + train.py \ + --num_epoch 1 \ + --model_name $MODEL \ + --plugin "hybrid" \ + --batch_size $BATCH_SIZE \ + --lr $LR \ + --zero_stage 1 \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 8 \ diff --git a/applications/ColossalMoE/version.txt b/applications/ColossalMoE/version.txt new file mode 100644 index 000000000..3eefcb9dd --- /dev/null +++ b/applications/ColossalMoE/version.txt @@ -0,0 +1 @@ +1.0.0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index e976d0aaf..45e5a23c1 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import ( ) from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.moe import MoECheckpintIO +from colossalai.moe import MOE_MANAGER, MoECheckpintIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig @@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self, tp_size: int, pp_size: int, + ep_size: int, extra_dp_size: int = 1, precision: str = "fp16", zero_stage: int = 0, @@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_communication: bool = True, use_ep_inside: bool = True, custom_policy: Policy = None, + checkpoint_io: Optional[MoECheckpintIO] = None, ) -> None: assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): if enable_sequence_parallelism: assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism" - + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + assert ( + dist.get_world_size() % (tp_size * pp_size * ep_size) == 0 + ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}" + self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size) + MOE_MANAGER.setup( + parallel="EP", + mode="fixed", + fixed_dp_size=self.real_dp_size, + fixed_ep_size=ep_size, + fixed_pp_size=pp_size, + use_ep_inside=use_ep_inside, + ) self.tp_size = tp_size self.pp_size = pp_size self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.ep_size = ep_size + self.moe_info = MOE_MANAGER.get_info(0)[1] self.precision = precision self.zero_stage = zero_stage self.cpu_offload = cpu_offload @@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + self.checkpoint_io = checkpoint_io # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) @@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ) def get_checkpoint_io(self) -> MoECheckpintIO: - self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + if self.checkpoint_io is None: + self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + else: + self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io def configure( diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 780117598..712324215 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.interface import ModelWrapper -from .utils import has_index_file +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file __all__ = ["CheckpointIO"] @@ -90,7 +90,15 @@ class CheckpointIO(ABC): if index_file_exists: self.load_sharded_model(model, index_file_path, strict) else: - self.load_unsharded_model(model, checkpoint, strict) + path = Path(checkpoint, SAFE_WEIGHTS_NAME) + if path.is_file(): + self.load_unsharded_model(model, str(path), strict) + else: + path = Path(checkpoint, WEIGHTS_NAME) + if path.is_file(): + self.load_unsharded_model(model, str(path), strict) + else: + self.load_unsharded_model(model, checkpoint, strict) return origin_model diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 721da69d0..6dd0a5fc3 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,6 +1,7 @@ from .checkpoint import MoECheckpintIO from .experts import MLPExperts -from .layers import SparseMLP +from .layers import SparseMLP, apply_load_balance +from .manager import MOE_MANAGER from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator @@ -14,4 +15,6 @@ __all__ = [ "UniformNoiseGenerator", "SparseMLP", "MoECheckpintIO", + "MOE_MANAGER", + "apply_load_balance", ] diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 34342436f..01c837ee3 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple import torch import torch.distributed as dist @@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function): if ctx.ep_size != 1: grad = grad / ctx.ep_size return grad, None + + +def _all_to_all( + inputs: torch.Tensor, + input_split_sizes: Optional[List[int]] = None, + output_split_sizes: Optional[List[int]] = None, + group=None, + async_op: bool = False, +): + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + outputs_shape = list(inputs.shape) + if output_split_sizes is not None: + outputs_shape[0] = sum(output_split_sizes) + outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device) + inputs = inputs.contiguous() + outputs = outputs.contiguous() + handle = dist.all_to_all_single( + outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op + ) + return outputs, handle + + +class AllToAllUneven(torch.autograd.Function): + @staticmethod + def forward( + ctx, + inputs, + input_split_sizes=None, + output_split_sizes=None, + group=None, + overlap: bool = False, + ): + """ + Returns: + outputs: Tensor + handle: Optional[Work], if overlap is True + """ + ctx.input_split_sizes = input_split_sizes + ctx.output_split_sizes = output_split_sizes + ctx.group = group + return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap) + + @staticmethod + def backward(ctx: Any, *grad_outputs): + return ( + _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0], + None, + None, + None, + None, + ) + + +def all_to_all_uneven( + inputs: torch.Tensor, + input_split_sizes: Optional[List[int]] = None, + output_split_sizes: Optional[List[int]] = None, + group=None, + overlap: bool = False, +): + return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap) diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py index a8c50eab6..b37ffabea 100644 --- a/colossalai/moe/checkpoint.py +++ b/colossalai/moe/checkpoint.py @@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ + torch.cuda.empty_cache() if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): f"index located at {save_index_file}." ) dist.barrier() + torch.cuda.empty_cache() # ======================================================== # Abstract methods for optimizer loading/saving implementation @@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO): assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" def _get_param_id_from_optimizer_param( - param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None + param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None ): if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param return optimizer.param_info["param2id"][id(working_param)] @@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): master_to_working_map = optimizer.get_master_to_working_map() for pg in optimizer.optim.param_groups: for param in pg["params"]: - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) id_map[param_id] = param # Read checkpoint index file. @@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO): new_pg = copy.deepcopy(saved_pg) new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. updated_groups.append(new_pg) - # ep extra group - if MOE_MANAGER.parallel == "EP": + # ep param group + if len(optimizer.optim.param_groups) > len(saved_groups): new_pg = copy.deepcopy(saved_pg) - new_pg["params"] = optimizer.optim.param_groups[-1][ - "params" - ] # Only keep the parameters kept by current pipeline stage. - for param in new_pg["params"]: - param.data = param.data.to(torch.float32) + new_pg["params"] = optimizer.optim.param_groups[-1]["params"] updated_groups.append(new_pg) optimizer.optim.__dict__.update({"param_groups": updated_groups}) @@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): for param in pg["params"]: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, master_to_working_map) + param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + + # Then shard the loaded optimizer states if using tp/zero. + for pid, state in list(state_dict.items()): + if pid in id_map: + param = id_map[pid] + if master_to_working_map is not None and id(param) in master_to_working_map: + working_param = master_to_working_map[id(param)] + elif ( + hasattr(optimizer, "moe_master_to_working_map") + and id(param) in optimizer.moe_master_to_working_map + ): + working_param = optimizer.moe_master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info["param2shape"][id(working_param)] + sharded_state = self.pre_load_optim( + state, + working_param, + current_shape=working_param.shape, + original_shape=original_shape, + device="cpu", + inplace=True, + ) + state_dict[pid] = sharded_state + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) - # Then shard the loaded optimizer states if using tp/zero. - for param, state in optimizer.optim.state.items(): - device = param.device - if master_to_working_map is not None and id(param) in master_to_working_map: - working_param = master_to_working_map[id(param)] - else: - working_param = param - original_shape = optimizer.param_info["param2shape"][id(working_param)] - sharded_state = self.pre_load_optim( - state, - param, - current_shape=working_param.shape, - original_shape=original_shape, - device=device, - inplace=True, - ) - optimizer.optim.state[param] = sharded_state - sharded_optimizer_loading_epilogue(optimizer.optim) if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO): if master_to_working_map is not None and id(param) in master_to_working_map: working_param = master_to_working_map[id(param)] + elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map: + working_param = optimizer.moe_master_to_working_map[id(param)] else: working_param = param @@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): prefix (str): Perfix of file to save size_per_shard (int): Max file size of each file shard that store state tensors """ + torch.cuda.empty_cache() assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") @@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): f"You can find where each parameters has been saved in the " f"index located at {final_index_file_path}." ) + torch.cuda.empty_cache() def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 477b76547..8e6ea3884 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -67,7 +67,11 @@ class MLPExperts(nn.Module): self.ep_size = 1 if gated: - self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2)) + self.wi_gate = nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size + ) + ) self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) else: self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index b768fb94a..2ac5b186d 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,6 +51,8 @@ class SparseMLP(nn.Module): hidden_size: int, intermediate_size: int, router_top_k: int = 1, + router_loss: bool = True, + router_norm: bool = False, router_capacity_factor_train: float = 1.25, router_capacity_factor_eval: float = 2.0, router_min_capacity: int = 4, @@ -65,15 +67,19 @@ class SparseMLP(nn.Module): enable_kernel: bool = False, enable_comm_overlap: bool = False, enable_hierarchical_comm: bool = False, + return_gate_logits: bool = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts self.gated = mlp_gated + self.return_gate_logits = return_gate_logits self.enable_kernel = enable_kernel self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() + self.router_loss = router_loss + self.router_norm = router_norm # moe router noisy_func = get_noise_generator(router_noisy_policy, num_experts) @@ -150,9 +156,8 @@ class SparseMLP(nn.Module): tokens = inputs.reshape(-1, self.hidden_size) # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) + gate_logits = F.linear(tokens, self.gate_weight) + gate_output = gate_logits.to(torch.float) # update expert load if self.enable_load_balance == True: @@ -165,7 +170,12 @@ class SparseMLP(nn.Module): # the result from the router used_capacity, *route_result_list = self.router( - inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) + inputs=gate_output, + use_kernel=self.enable_kernel, + ep_group=self.ep_group, + use_loss=self.router_loss, + use_norm=self.router_norm, + ) # dispatch_data: (num_experts, capacity, hidden_size) if self.enable_kernel: @@ -177,22 +187,15 @@ class SparseMLP(nn.Module): # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process( - dispatch_data, - used_capacity, - overlap=self.enable_comm_overlap - ) + expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n" - "Please use Experts build function.") + raise NotImplementedError( + "This kind of communication has not been implemented yet.\n" "Please use Experts build function." + ) if self.enable_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) @@ -204,7 +207,11 @@ class SparseMLP(nn.Module): ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - return ans + + if self.return_gate_logits: + return ans, gate_logits + else: + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) @@ -212,10 +219,7 @@ class SparseMLP(nn.Module): return expert_out def _ep_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ Expert Parallel @@ -228,10 +232,14 @@ class SparseMLP(nn.Module): """ if not overlap or dist.get_world_size(self.ep_group) == 1: if self.ep_hierarchical_group is not None: - expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_input = HierarchicalAllToAll.apply( + dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank + ) expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size) expert_output = self.experts(expert_input) - expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank) + expert_output = HierarchicalAllToAll.apply( + expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank + ) return expert_output else: expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0] @@ -249,7 +257,7 @@ class SparseMLP(nn.Module): NUM_CHUNK = 4 NUM_STAGES = 4 - assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" + assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -262,13 +270,15 @@ class SparseMLP(nn.Module): for i in range(NUM_CHUNK + NUM_STAGES - 1): if expert_out is not None: expert_out.handle.wait() - output[:, :, offset:offset + chunk_size, :] = expert_out.data + output[:, :, offset : offset + chunk_size, :] = expert_out.data offset += chunk_size expert_out = None # all2all last output if _expert_out is not None: - expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) + expert_out = Capsule( + *AllToAll.apply(_expert_out.data, self.ep_group, True), + ) _expert_out = None # all2all next input @@ -288,10 +298,7 @@ class SparseMLP(nn.Module): return output def _tp_process( - self, - dispatch_data: torch.Tensor, - used_capacity: torch.Tensor, - overlap: bool = False + self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False ) -> torch.Tensor: """ without overlap: @@ -326,8 +333,9 @@ class SparseMLP(nn.Module): NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert ( + dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index f5815d05d..e40674c9b 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC): self._z_loss = None self.use_kernel = use_kernel - def get_capacity(self, logits_shape): + def get_capacity(self, num_tokens, num_experts, ep_group=None): + if ep_group is not None: + num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device()) + dist.all_reduce(num_tokens_tensor, group=ep_group) + num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) capacity += capacity % 2 capacity = max(capacity, self.min_capacity) assert capacity > 0 @@ -150,7 +154,14 @@ class Top1Router(MoeRouter): high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_loss: bool = False, + use_norm: bool = False, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -168,7 +179,8 @@ class Top1Router(MoeRouter): assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -207,7 +219,7 @@ class Top1Router(MoeRouter): weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() - return used_capacity, combine_weights, sec_mask + return used_capacity, combine_weights, sec_mask, probs class Top2Router(MoeRouter): @@ -240,7 +252,14 @@ class Top2Router(MoeRouter): drop_tks=drop_tks, ) - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: + def forward( + self, + inputs: torch.Tensor, + use_kernel: bool = False, + ep_group: Optional[ProcessGroup] = None, + use_norm: bool = False, + use_loss: bool = True, + ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). @@ -257,8 +276,13 @@ class Top2Router(MoeRouter): assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) + if use_norm: + routing_weights, _ = torch.topk(probs, 2, dim=-1) + probs = probs / routing_weights.sum(dim=-1, keepdim=True) + num_experts = probs.size(-1) - capacity = self.get_capacity(inputs.shape) + num_tokens = inputs.size(0) + capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) @@ -270,10 +294,11 @@ class Top2Router(MoeRouter): cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss - expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) - self.set_aux_loss(probs, expert_indices, num_experts) - self.set_z_loss(inputs) - self.pop_router_loss() + if use_loss: + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index e25e7dd48..c642f1a44 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable: return torch.nn.GELU() elif act == "swiglu": return SwiGLU + elif act == "silu": + return torch.nn.SiLU() else: raise NotImplementedError("Unsupported activation function") diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index ba6c77056..5ac3c2b3a 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -26,3 +26,5 @@ class MoeParallelInfo: self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group) + self.ep_rank = self.pg.coordinate(self.ep_axis) + self.dp_rank = self.pg.coordinate(self.dp_axis) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index e01c852be..a2433d1b2 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # because they have different parallel strategy # so we need to store them separately in param_groups # instead of working_groups - moe_params = list() + self.working_moe_params = list() # iterate over the param group in the optimizer # partition these param groups for data parallel training @@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if self.moe_extra_dp_pg is None: # skip moe param if is_moe_tensor(param): - moe_params.append(param) + self.working_moe_params.append(param) continue group_params.append(param) @@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # managed by this data parallel rank param_group["params"] = master_param_current_rank - # if there are moe params, store in additional group in optim - if len(moe_params) > 0: + # if there are moe params, store in addtional group in optim + if len(self.working_moe_params) > 0: + self._sync_master_param = False param_group = dict() + # create fp32 master param for key, value in self.optim.param_groups[0].items(): if key != "params": param_group[key] = value - param_group["params"] = moe_params + self.master_moe_params = [] + for param in self.working_moe_params: + self.master_moe_params.append(param.clone().to(torch.float32).detach()) + # create mapping from master to working for optimizer io + self.moe_master_to_working_map = {} + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param + # add to optim + param_group["params"] = self.master_moe_params self.optim.param_groups.append(param_group) # initialize communication stream for @@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # update the params in the optimizer self.optim.param_groups[group_id]["params"] = real_master_params[group_id] + # update param for moe ep + # move grad to master param and compute norm + if len(self.working_moe_params) > 0: + moe_grads = [] + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + if master_moe_param.grad is not None: + raise RuntimeError("Moe param should not have grad here") + grad = working_moe_param.grad + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(master_moe_param.dtype).to(master_moe_param.device) + master_moe_param.grad = grad + working_moe_param.grad = None + moe_grads.append(grad) + grad_partition_groups.append(grad) + norm_group = self._compute_grad_norm(gradients=moe_grads) + norm_groups.append(norm_group) + self.optim.param_groups[-1]["params"] = self.master_moe_params + del moe_grads + # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) - # TODO: we should store master param for ep - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.data = param.data.to(torch.float32) - param.grad = param.grad.to(torch.float32) - # update the parameters self.optim.step() - # release the moe gradm - if len(self.param_groups) > len(self._working_param_groups): - for param in self.param_groups[-1]["params"]: - param.grad = None - param.data = param.data.to(self._dtype) + # release moe grad + if len(self.working_moe_params) > 0: + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.grad = None + working_moe_param.data = ( + master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach() + ) # release the grad grad_partition_groups = [] @@ -885,9 +911,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) else: master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + if hasattr(self, "master_moe_params"): + for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params): + master_moe_param.copy_(working_moe_param) def get_working_to_master_map(self) -> Dict[int, torch.Tensor]: return self._param_store.working_to_master_param def get_master_to_working_map(self) -> Dict[int, torch.Tensor]: + if hasattr(self, "moe_master_to_working_map"): + return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map} return self._param_store.master_to_working_param diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 721a4796a..17b790e3e 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -1,13 +1,22 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce from colossalai.legacy.registry import GRADIENT_HANDLER from colossalai.moe import SparseMLP from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_moe_epsize_param_dict +from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size + + +def delete_moe_info(model): + for _, param in model.named_parameters(): + if hasattr(param, "moe_info"): + delattr(param, "moe_info") class MoeModel(nn.Module): @@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None): for i in range(world_size - 1): a = tensor_list[i] b = tensor_list[i + 1] - assert not torch.allclose(a, b), \ - (f"expected tensors on rank {i} and {i + 1} not to be equal " - f"but they are, {a} vs {b}") + assert not torch.allclose(a, b), ( + f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}" + ) + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss) + else: + loss.backward() + return y + + +def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None: + """Sync the parameters of tp model from ep model + + Args: + local_model (MoeModule) + ep_model (MoeModule) + """ + for (local_name, local_param), (ep_name, ep_param) in zip( + local_model.named_parameters(), ep_model.named_parameters() + ): + assert local_name in ep_name, print(f"{local_name} != {ep_name}") + if "experts" not in local_name: + if assert_grad_flag: + assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}" + assert torch.allclose(local_param.grad, ep_param.grad) + else: + local_param.data.copy_(ep_param.data) + continue + + # gather param from ep model + param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param)) + all_param = torch.cat(param_list, dim=0) + if assert_grad_flag: + grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))] + dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param)) + all_grad = torch.cat(grad_list, dim=0) + + if assert_grad_flag: + assert torch.allclose(local_param, all_param) + assert torch.allclose(local_param.grad, all_grad) + else: + local_param.data.copy_(all_param.data) + + +def loose_close(a, b, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + a = a.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) + + assert_close(a, b, rtol=rtol, atol=atol) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 8f51e1663..d6dad2d7f 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -12,7 +12,6 @@ import colossalai from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn sys.path.append( @@ -95,6 +94,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=1, zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) @@ -103,6 +103,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=dist.get_world_size(), zero_stage=2, custom_policy=OpenMoeForCausalLMPolicy(), ) @@ -111,6 +112,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=1, + ep_size=2, zero_stage=2, extra_dp_size=2, custom_policy=OpenMoeForCausalLMPolicy(), @@ -120,6 +122,7 @@ def get_model(parallel): precision="bf16", tp_size=1, pp_size=2, + ep_size=2, zero_stage=1, microbatch_size=1, custom_policy=OpenMoeForCausalLMPolicy(), @@ -130,27 +133,6 @@ def get_model(parallel): def _test_moe_checkpoint(rank, parallel): - if parallel == None: - MOE_MANAGER.setup( - parallel=None, - ) - elif parallel == "ep": - MOE_MANAGER.setup( - parallel="EP", - ) - elif parallel == "ep_zero": - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=2, - ) - elif parallel == "hybrid": - MOE_MANAGER.setup( - parallel="EP", - mode="fixed", - fixed_dp_size=1, - fixed_ep_size=2, - fixed_pp_size=2, - ) model1, booster1, optim1 = get_model(parallel) model2, booster2, optim2 = get_model(parallel) model3, booster3, optim3 = get_model(parallel) @@ -207,6 +189,7 @@ def _run_dist(rank, world_size, port, parallel): _test_moe_checkpoint(rank, parallel) +@pytest.mark.skip(reason="This is tested in ColossalMOE") @pytest.mark.dist @pytest.mark.parametrize("world_size", [4]) @pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"]) diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py index 7ba7fa6f6..9f6167692 100644 --- a/tests/test_moe/test_moe_router.py +++ b/tests/test_moe/test_moe_router.py @@ -4,15 +4,21 @@ import torch from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter -@pytest.mark.parametrize(["router", "num_groups"], [ - (Top1Router(), 1), - (Top2Router(), 1), - # (TopKRouter(num_selected_experts=3), 4), -]) -@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [ - (4, 5, 8), - (3, 4, 4), -]) +@pytest.mark.parametrize( + ["router", "num_groups"], + [ + (Top1Router(), 1), + (Top2Router(), 1), + # (TopKRouter(num_selected_experts=3), 4), + ], +) +@pytest.mark.parametrize( + ["batch_size", "seq_len", "num_experts"], + [ + (4, 5, 8), + (3, 4, 4), + ], +) def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int): x = torch.randn((batch_size * seq_len, num_experts)).cuda() if num_groups > 1: @@ -20,18 +26,18 @@ def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_ex router.train() if isinstance(router, TopKRouter): - _, combine_array, dispatch_mask = router(x, expert_capacity=2) + combine_array, dispatch_mask = router(x, expert_capacity=2) else: - _, combine_array, dispatch_mask = router(x) + combine_array, dispatch_mask = router(x)[1:3] assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) router.eval() if isinstance(router, TopKRouter): - _, combine_array, dispatch_mask = router(x, expert_capacity=2) + combine_array, dispatch_mask = router(x, expert_capacity=2) else: - _, combine_array, dispatch_mask = router(x) + combine_array, dispatch_mask = router(x)[1:3] assert combine_array.shape[:-1] == x.shape assert dispatch_mask.shape[:-1] == x.shape assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index f0795a4c7..1bff21066 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -4,102 +4,75 @@ import torch import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters()) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) + sync_local_from_ep(zero_model, moe_model) - # assert zero model - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.module.named_parameters() - ): - assert zero_name == torch_name - assert torch.allclose(zero_param.data, torch_param.data) - - data = torch.randn(16, 4).cuda() + data = torch.randn(16, 4).bfloat16().cuda() label = torch.randint(0, 4, (16,)).cuda() - torch_out = run_fwd_bwd(torch_model, data, label, criterion, None) - zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer) - assert torch.allclose(torch_out, zero_out) - grad_handler.handle_gradient() + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + assert torch.allclose(zero_out, moe_out) - for (zero_name, zero_param), (torch_name, torch_param) in zip( - zero_model.module.named_parameters(), torch_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.module.named_parameters(), zero_model.module.named_parameters() ): - assert zero_name == torch_name - zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) - if hasattr(zero_param, "moe_info"): - assert len(zero_grad_list) == 0 - assert torch.allclose(zero_param.grad, torch_param.grad) + assert moe_name == zero_name + moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param)) + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param)) + if hasattr(moe_param, "moe_info"): + assert len(moe_grad_list) == 0 + if stage == 1: + zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape) + else: + zero_grad = zero_grad_list[0].view(moe_param.grad.shape) + assert torch.allclose( + moe_param.grad, zero_grad, atol=1e-5 + ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}" else: - assert len(zero_grad_list) > 0 - torch_grad_list = split_ddp_grad(torch_param.grad, world_size) - if stage == 2: - torch_grad_list = torch_grad_list[local_rank : local_rank + 1] - assert len(zero_grad_list) == len(torch_grad_list) - for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): - assert torch.allclose(zero_grad, torch_grad) + assert len(moe_grad_list) > 0 + assert len(moe_grad_list) == len(zero_grad_list) + for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list): + assert torch.allclose(moe_grad, zero_grad) -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") seed_all(42 + rank) - run_zero_test(rank, world_size, stage=1) - run_zero_test(rank, world_size, stage=2) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) +def test_moe_zero_model(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_model(world_size=2) + test_moe_zero_model(world_size=2, stage=1) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 0d2e2fb1b..4f6067aaa 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -4,89 +4,80 @@ import torch import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn -from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel +from colossalai.testing.random import seed_all +from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep -def split_ddp_grad(grad, world_size): - with torch.no_grad(): - grad = grad.clone().detach().flatten() - padding_size = (world_size - grad.numel() % world_size) % world_size - if padding_size > 0: - grad = torch.nn.functional.pad(grad, [0, padding_size]) - splited_grad = grad.split(grad.numel() // world_size) - return splited_grad - - -def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - - if isinstance(model, LowLevelZeroModel): - optimizer.backward(loss) - else: - loss.backward() - return y - - -def run_zero_optim_test(local_rank, world_size, stage=1): +def run_zero_test(local_rank, stage=1): criterion = torch.nn.CrossEntropyLoss() - zero_model = MoeModel() - zero_optimizer = torch.optim.Adam(zero_model.parameters()) - plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") - booster = Booster(plugin=plugin) - zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel="EP") + moe_model = MoeModel().bfloat16() + moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0) + moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + moe_booster = Booster(plugin=moe_plugin) + moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer) - torch_model = MoeModel() - for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): - torch_param.data.copy_(zero_param.data) - torch_optimizer = torch.optim.Adam(torch_model.parameters()) - torch_model = torch_model.cuda() - grad_handler = MoeGradientHandler(torch_model) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(parallel=None) + zero_model = MoeModel().bfloat16() + delete_moe_info(zero_model) + sync_local_from_ep(zero_model, moe_model) + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0) + zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16") + zero_booster = Booster(plugin=zero_plugin) + zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer) - for _ in range(2): - data = torch.randn(16, 4).cuda() / (local_rank + 1) - label = torch.randint(0, 4, (16,)).cuda() - run_fwd_bwd(torch_model, data, label, criterion, None) - run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) - grad_handler.handle_gradient() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() + ): + if ".experts." in moe_name: + continue + assert moe_name == zero_name + assert torch.allclose( + moe_param.data, zero_param.data + ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}" - torch_optimizer.step() + for _ in range(1): + data = torch.randn(2, 4).bfloat16().cuda() + label = torch.randint(0, 4, (2,)).cuda() + + moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer) + zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + assert torch.allclose(zero_out, moe_out) + moe_optimizer.step() zero_optimizer.step() - for (torch_name, torch_param), (zero_name, zero_param) in zip( - torch_model.named_parameters(), zero_model.named_parameters() + for (moe_name, moe_param), (zero_name, zero_param) in zip( + moe_model.named_parameters(), zero_model.named_parameters() ): - assert torch.allclose( - torch_param.data, zero_param.data - ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + assert moe_name == zero_name + if is_moe_tensor(moe_param): + param_size = moe_param.shape[0] + zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size] + loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype) - torch_optimizer.zero_grad() + moe_optimizer.zero_grad() zero_optimizer.zero_grad() -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, port, stage): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - MOE_MANAGER.setup(parallel="EP") - run_zero_optim_test(rank, world_size, stage=1) - run_zero_optim_test(rank, world_size, stage=2) + seed_all(42 + rank) + run_zero_test(rank, stage=stage) @pytest.mark.dist @pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("stage", [1, 2]) @rerun_if_address_is_in_use() -def test_moe_zero_optim(world_size): - spawn(run_dist, world_size) +def test_moe_zero_optim(world_size, stage): + spawn(run_dist, world_size, stage=stage) if __name__ == "__main__": - test_moe_zero_optim(world_size=2) + test_moe_zero_optim(world_size=2, stage=1)