mirror of https://github.com/hpcaitech/ColossalAI
[moe] support mixtral (#5309)
* [moe] add mixtral block for single expert * [moe] mixtral block fwd support uneven ep * [moe] mixtral block bwd support uneven ep * [moe] add mixtral moe layer * [moe] simplify replace * [meo] support save sharded mixtral * [meo] support load sharded mixtral * [meo] support save sharded optim * [meo] integrate moe manager into plug * [meo] fix optimizer load * [meo] fix mixtral layerpull/5372/head
parent
c904d2ae99
commit
da39d21b71
|
@ -1,205 +1,617 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
|
||||
from colossalai.moe import MoECheckpintIO
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
|
||||
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
|
||||
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
StateDictSharder,
|
||||
gather_distributed_param,
|
||||
get_model_base_filenames,
|
||||
get_optimizer_base_filenames,
|
||||
load_shard_state_dict,
|
||||
load_states_into_optimizer,
|
||||
save_config_file,
|
||||
save_param_groups,
|
||||
save_state_dict_shards,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||||
|
||||
|
||||
class MixtralMoECheckpointIO(MoECheckpintIO):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
|
||||
def __init__(
|
||||
self,
|
||||
dp_group: ProcessGroup,
|
||||
pp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
zero_stage: int,
|
||||
verbose: bool = True,
|
||||
) -> None:
|
||||
super().__init__(dp_group, pp_group, tp_group, zero_stage, verbose)
|
||||
moe_info = MOE_MANAGER.parallel_info_dict[MOE_MANAGER.ep_size]
|
||||
self.ep_group = moe_info.ep_group
|
||||
self.ep_size = moe_info.ep_size
|
||||
self.ep_rank = moe_info.ep_rank
|
||||
self.real_dp_rank = moe_info.dp_rank
|
||||
|
||||
@torch.no_grad()
|
||||
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
|
||||
@staticmethod
|
||||
def _model_sharder(
|
||||
model: nn.Module,
|
||||
prefix: str = "",
|
||||
keep_vars: bool = False,
|
||||
size_per_shard: int = 1024,
|
||||
param_name_pattern: Optional[str] = None,
|
||||
) -> Iterator[Tuple[OrderedDict, int]]:
|
||||
# An internel method that breaks state_dict of model into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
|
||||
# Save parameters.
|
||||
for name, param in model.named_parameters():
|
||||
if param is None:
|
||||
continue
|
||||
if param_name_pattern is not None and param_name_pattern not in name:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Save buffers.
|
||||
for name, buf in model.named_buffers():
|
||||
if buf is not None and name not in model._non_persistent_buffers_set:
|
||||
buffer = buf if keep_vars else buf.detach()
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Save extra states.
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if (
|
||||
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
|
||||
is not torch.nn.Module.get_extra_state
|
||||
):
|
||||
extra_state = model.get_extra_state()
|
||||
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_model(
|
||||
self,
|
||||
model: ModelWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
|
||||
"""
|
||||
model_param_dict = dict(model.named_parameters())
|
||||
for name, param in list(state_dict.items()):
|
||||
if ".gate.weight" in name:
|
||||
new_name = "module." + name.replace(".gate.weight", ".gate_weight")
|
||||
state_dict[new_name] = state_dict.pop(name)
|
||||
elif ".experts." in name:
|
||||
# if is moe tensor
|
||||
# in our moe module, expert is cat as one tensor
|
||||
# but mixtral's experts is not cat
|
||||
# we will insert the loaded expert into the position of cat tensor
|
||||
Save sharded model checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
|
||||
- Multiple files that store state tensors of models.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
|
||||
|
||||
# get model param
|
||||
str_idx = name.index(".experts.")
|
||||
expert_idx = int(name.split(".")[-3])
|
||||
if ".w1." in name:
|
||||
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
|
||||
elif ".w2." in name:
|
||||
model_param_name = name.replace(name[str_idx:], ".experts.wo")
|
||||
elif ".w3." in name:
|
||||
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
|
||||
model_param_name = "module." + model_param_name
|
||||
# skip for pipeline
|
||||
if model_param_name not in model_param_dict:
|
||||
continue
|
||||
model_param = model_param_dict[model_param_name]
|
||||
assert is_moe_tensor(model_param)
|
||||
# get expert range
|
||||
ep_rank = get_ep_rank(model_param)
|
||||
ep_size = get_ep_size(model_param)
|
||||
expert_num = 8 // ep_size
|
||||
expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
|
||||
# insert new param
|
||||
if expert_idx in expert_range:
|
||||
new_param = model_param
|
||||
new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
|
||||
state_dict[model_param_name] = new_param
|
||||
state_dict.pop(name)
|
||||
else:
|
||||
new_name = "module." + name
|
||||
state_dict[new_name] = state_dict.pop(name)
|
||||
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
|
||||
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
|
||||
"""
|
||||
Load sharded model with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
|
||||
This argument should be manually set to False since params on same device might be stored in different files.
|
||||
model (nn.Module): Model on local device to be saved.
|
||||
checkpoint (str): Checkpointing path which should be a directory path.
|
||||
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
|
||||
prefix (str, optional): Perfix of file to save. Defaults to None.
|
||||
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
|
||||
"""
|
||||
|
||||
# Check whether the checkpoint uses safetensors.
|
||||
use_safetensors = False
|
||||
if "safetensors" in checkpoint_index_file.name:
|
||||
use_safetensors = True
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model = model.unwrap()
|
||||
|
||||
if use_safetensors and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.real_dp_rank != 0:
|
||||
return
|
||||
|
||||
# ep_rank 0 saves all the parameters and buffers.
|
||||
# other ep_ranks save only experts
|
||||
ep_param_pattern = "experts." if self.ep_rank != 0 else None
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
|
||||
)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the model shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
save_config_file(model, checkpoint)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
|
||||
weights_name = weights_name.replace(
|
||||
".safetensors", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.safetensors"
|
||||
)
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=weights_name,
|
||||
is_master=control_saving,
|
||||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
return
|
||||
|
||||
dist.barrier(self.pp_group)
|
||||
dist.barrier(self.ep_group)
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.coordinator.is_master():
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for weight, weight_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(weight, weight_filename)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
save_config_file(model, checkpoint)
|
||||
rmtree(tmp_index_file_folder)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def gather_from_sharded_optimizer_state(
|
||||
state: OrderedDict,
|
||||
param: torch.Tensor,
|
||||
original_shape: torch.Size,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
use_zero: bool,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With given parameter and its optimizer states, gather the complete optimizer state for saving.
|
||||
|
||||
Args:
|
||||
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
|
||||
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
dp_group (ProcessGroup): The process group of data parallel.
|
||||
tp_group (ProcessGroup): The process group of tensor parallel.
|
||||
use_zero (bool): Whether Zero is used.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
device (torch.device): The destination device of loaded optimizer states. Defaults to torch.device('cpu').
|
||||
|
||||
Returns:
|
||||
OrderedDict: The complete optimizer state of given parameter.
|
||||
"""
|
||||
dp_size = dist.get_world_size(dp_group)
|
||||
tp_size = dist.get_world_size(tp_group)
|
||||
current_shape = param.shape
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# First gather Zero shards.
|
||||
if use_zero and not is_moe_param:
|
||||
v = v.cuda()
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=dp_group)
|
||||
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
|
||||
|
||||
# Then gather TP shards.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
|
||||
if partition_dim is not None:
|
||||
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
|
||||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
@staticmethod
|
||||
def _optimizer_sharder(
|
||||
optimizer: OptimizerWrapper,
|
||||
use_zero: bool,
|
||||
dp_group: ProcessGroup,
|
||||
tp_group: ProcessGroup,
|
||||
size_per_shard: int = 1024,
|
||||
only_moe_param: bool = False,
|
||||
):
|
||||
# An internel method that breaks state_dict of optimizer into shards within limited size.
|
||||
|
||||
state_dict_sharder = StateDictSharder(size_per_shard)
|
||||
param_info = optimizer.param_info
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
|
||||
for param, state in optimizer.optim.state.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
|
||||
param_id = param_info["param2id"][id(working_param)]
|
||||
original_shape = param_info["param2shape"][id(working_param)]
|
||||
state_ = MixtralMoEHybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
|
||||
state,
|
||||
working_param,
|
||||
original_shape=original_shape,
|
||||
dp_group=dp_group,
|
||||
tp_group=tp_group,
|
||||
use_zero=use_zero,
|
||||
inplace=False,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
)
|
||||
|
||||
if only_moe_param and not is_moe_tensor(working_param):
|
||||
continue
|
||||
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
|
||||
if block is not None:
|
||||
yield block, block_size
|
||||
|
||||
# Return the last block in sharder.
|
||||
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
|
||||
|
||||
def save_sharded_optimizer(
|
||||
self,
|
||||
optimizer: OptimizerWrapper,
|
||||
checkpoint: str,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
):
|
||||
"""
|
||||
Save sharded optimizer checkpoint under the given checkpointing path.
|
||||
The following files will be created under the path:
|
||||
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
|
||||
- A group file (pytorch_optim_group.bin) recording information of param_groups
|
||||
- Multiple files that store state tensors of optimizers.
|
||||
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
|
||||
If pipeline parallelism is not used, "pytorch_optim.<prefix>-000XX.bin"
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
|
||||
checkpoint (str): Path to save optimizer state_dict
|
||||
gather_dtensor (bool): Whether to gather_dtensor, not used
|
||||
prefix (str): Perfix of file to save
|
||||
size_per_shard (int): Max file size of each file shard that store state tensors
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
Path(checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Devices along the same dp_group share the same copies of states when zero is not used.
|
||||
# In this case only let the device with dp_rank == 0 save the model.
|
||||
if not self.use_zero and self.real_dp_rank != 0:
|
||||
return
|
||||
|
||||
# Then collect the sharded states along dp_group(if using zero)/tp_group.
|
||||
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
|
||||
state_dict_shard = MixtralMoEHybridParallelCheckpointIO._optimizer_sharder(
|
||||
optimizer,
|
||||
use_zero=self.use_zero,
|
||||
dp_group=self.dp_group,
|
||||
tp_group=self.tp_group,
|
||||
size_per_shard=size_per_shard,
|
||||
only_moe_param=self.ep_rank != 0,
|
||||
)
|
||||
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.real_dp_rank == 0 and self.tp_rank == 0
|
||||
|
||||
if self.pp_size == 1 and self.ep_size == 1:
|
||||
# When pipeline is not used, save the optimizer shards as in general checkpointIO
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
# Store param groups.
|
||||
index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
# Store index file.
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The optimizer is going to be split to checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
|
||||
else:
|
||||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Manage filenames of sharded weights and index file for each pipeline stage.
|
||||
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}-shard.bin")
|
||||
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}-{self.ep_rank+1:05d}.json")
|
||||
save_index_file = os.path.join("tmp_index_files", save_index_file)
|
||||
|
||||
total_size = save_state_dict_shards(
|
||||
sharded_state_dict=state_dict_shard,
|
||||
checkpoint=checkpoint,
|
||||
index_file=index_file,
|
||||
base_filename=states_name,
|
||||
is_master=control_saving,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
index_file.append_meta_data("total_size", total_size)
|
||||
index_file.write_index_file(save_index_file)
|
||||
else:
|
||||
return
|
||||
|
||||
dist.barrier(self.pp_group)
|
||||
dist.barrier(self.ep_group)
|
||||
|
||||
# The global master rank integrates the index files and clean the folder.
|
||||
if self.coordinator.is_master():
|
||||
final_index_file = CheckpointIndexFile(checkpoint)
|
||||
final_index_file.append_meta_data("total_size", 0)
|
||||
|
||||
for filename in os.listdir(tmp_index_file_folder):
|
||||
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
|
||||
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
|
||||
for param_id, state_filename in stage_index_file.weight_map.items():
|
||||
final_index_file.append_weight_map(param_id, state_filename)
|
||||
|
||||
# Store param groups.
|
||||
final_index_file.append_meta_data("param_groups", param_group_file)
|
||||
group_file_path = os.path.join(checkpoint, param_group_file)
|
||||
save_param_groups(optimizer.param_info, group_file_path)
|
||||
|
||||
final_index_file.write_index_file(final_index_file_path)
|
||||
rmtree(tmp_index_file_folder)
|
||||
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(
|
||||
f"The model is split into checkpoint shards. "
|
||||
f"You can find where each parameters has been saved in the "
|
||||
f"index located at {final_index_file_path}."
|
||||
)
|
||||
|
||||
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
|
||||
"""
|
||||
Load sharded optimizer with the given path to index file of checkpoint folder.
|
||||
|
||||
Args:
|
||||
optimizer (OptimizerWrapper): The optimizer to be loaded.
|
||||
checkpoint_index_file (str): Path to the index file of checkpointing folder.
|
||||
prefix (str): Not used.
|
||||
"""
|
||||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
|
||||
def _get_param_id_from_optimizer_param(
|
||||
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
|
||||
):
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
return optimizer.param_info["param2id"][id(working_param)]
|
||||
|
||||
# id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
|
||||
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
|
||||
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
|
||||
id_map = {}
|
||||
master_to_working_map = optimizer.get_master_to_working_map()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
id_map[param_id] = param
|
||||
|
||||
# Read checkpoint index file.
|
||||
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
|
||||
ckpt_root_path = ckpt_index_file.root_path
|
||||
weight_map = ckpt_index_file.weight_map
|
||||
strict = False
|
||||
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
|
||||
|
||||
# Load params & buffers to model.
|
||||
# Load param_groups
|
||||
param_group_path = ckpt_index_file.get_param_group_filename()
|
||||
if param_group_path is None:
|
||||
raise RuntimeError(
|
||||
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
|
||||
Lacking param group file under current directory."
|
||||
)
|
||||
saved_groups = torch.load(param_group_path)
|
||||
|
||||
updated_groups = []
|
||||
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
|
||||
# obtain updated param group
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
|
||||
updated_groups.append(new_pg)
|
||||
# ep param groups
|
||||
if len(optimizer.optim.param_groups) == len(saved_groups) + 1:
|
||||
new_pg = copy.deepcopy(saved_pg)
|
||||
new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
|
||||
updated_groups.append(new_pg)
|
||||
optimizer.optim.__dict__.update({"param_groups": updated_groups})
|
||||
|
||||
# Load saved states to optimizer.
|
||||
# Keep a record of loaded files so that file will not be repeatedly loaded.
|
||||
loaded_file = set()
|
||||
for pg in optimizer.optim.param_groups:
|
||||
for param in pg["params"]:
|
||||
if param is None:
|
||||
continue
|
||||
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
|
||||
if param_id not in weight_map:
|
||||
continue
|
||||
filename = weight_map[param_id]
|
||||
|
||||
def _load(name: str):
|
||||
if name not in weight_map:
|
||||
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
|
||||
filename = weight_map[name]
|
||||
# If this param's states has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
continue
|
||||
|
||||
# If this param/buffer has been loaded before, directly return.
|
||||
if filename in loaded_file:
|
||||
return
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
|
||||
state_dict = self.pre_load_model(model, state_dict)
|
||||
missing_keys = []
|
||||
|
||||
load_state_dict_into_model(
|
||||
model,
|
||||
state_dict,
|
||||
missing_keys=missing_keys,
|
||||
strict=strict,
|
||||
load_sub_module=True,
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
device = param.device
|
||||
if master_to_working_map is not None:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.shard_from_complete_optimizer_state(
|
||||
state,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
is_moe_param=is_moe_tensor(working_param),
|
||||
)
|
||||
loaded_file.add(filename)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
# Load parameters.
|
||||
for name, _ in model.named_parameters():
|
||||
name = name.replace("module.", "")
|
||||
name = name.replace(".gate_weight", ".gate.weight")
|
||||
if ".experts.wi_gate" in name:
|
||||
for i in range(8):
|
||||
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
|
||||
_load(new_name)
|
||||
elif ".experts.wi_up" in name:
|
||||
for i in range(8):
|
||||
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
|
||||
_load(new_name)
|
||||
elif ".experts.wo" in name:
|
||||
for i in range(8):
|
||||
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
|
||||
_load(new_name)
|
||||
else:
|
||||
_load(name)
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
||||
if self.verbose:
|
||||
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
def shard_from_complete_optimizer_state(
|
||||
self,
|
||||
state: OrderedDict,
|
||||
current_shape: torch.Size,
|
||||
original_shape: torch.Size,
|
||||
device: torch.device,
|
||||
inplace: bool,
|
||||
is_moe_param: bool,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
With complete optimizer states of a specific parameter loaded from checkpoint,
|
||||
slice out the sharded optimizer states kept by current device.
|
||||
|
||||
@torch.no_grad()
|
||||
def pre_save_model(self, model: nn.Module) -> dict:
|
||||
torch.cuda.empty_cache()
|
||||
state_dict = model.state_dict()
|
||||
for name, param in list(model.named_parameters()):
|
||||
if ".gate_weight" in name:
|
||||
new_name = name.replace(".gate_weight", ".gate.weight")
|
||||
state_dict[new_name] = state_dict.pop(name).cpu()
|
||||
elif ".experts." in name:
|
||||
ep_group = get_ep_group(param)
|
||||
ep_rank = get_ep_rank(param)
|
||||
ep_size = get_ep_size(param)
|
||||
dp_rank = get_dp_rank(param)
|
||||
Args:
|
||||
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
|
||||
current_shape (torch.Size): The size of parameter after sharding.
|
||||
original_shape (torch.Size): The size of parameter before sharding.
|
||||
device (torch.device): The destination device of loaded optimizer states.
|
||||
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
|
||||
|
||||
if dp_rank == 0:
|
||||
param = param.data.cuda()
|
||||
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
|
||||
# gather param from every ep rank
|
||||
dist.all_gather(all_param, param, group=ep_group)
|
||||
if ep_rank == 0:
|
||||
all_param = torch.cat(all_param, dim=0)
|
||||
assert all_param.shape[0] == 8
|
||||
for i in range(8):
|
||||
if ".wi_gate" in name:
|
||||
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
|
||||
elif ".wi_up" in name:
|
||||
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
|
||||
elif ".wo" in name:
|
||||
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
|
||||
new_name = new_name.replace("module.", "")
|
||||
new_param = all_param[i].transpose(-1, -2)
|
||||
state_dict[new_name] = new_param.cpu()
|
||||
state_dict.pop(name)
|
||||
else:
|
||||
state_dict[name] = param.cpu()
|
||||
Returns:
|
||||
OrderedDict: The sharded optimizer state of the given parameter.
|
||||
"""
|
||||
state_ = state if inplace else copy.deepcopy(state)
|
||||
|
||||
for name, param in list(state_dict.items()):
|
||||
new_name = name.replace("module.", "")
|
||||
state_dict[new_name] = state_dict.pop(name)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if self.pp_size > 1:
|
||||
if self.dp_rank == 0:
|
||||
# gather state_dict from every pp rank
|
||||
# because ckpt is large, we split it into 10 parts
|
||||
# and gather them one by one
|
||||
new_state_dict = {}
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
gap_key_num = min(30, len(state_dict_keys))
|
||||
gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num
|
||||
for i in range(gap_key_num):
|
||||
cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
|
||||
cur_state_dict = {}
|
||||
for k in cur_keys:
|
||||
cur_state_dict[k] = state_dict[k]
|
||||
out = [None for _ in range(self.pp_size)]
|
||||
dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
|
||||
if self.pp_rank == 0:
|
||||
for o in out:
|
||||
for k, v in o.items():
|
||||
new_state_dict[k] = v.cpu()
|
||||
state_dict = new_state_dict
|
||||
dist.barrier()
|
||||
return state_dict
|
||||
for k, v in state_.items():
|
||||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
if partition_dim is not None:
|
||||
slice_size = current_shape[partition_dim]
|
||||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
||||
# Shard state along data parallel group when using Zero.
|
||||
if self.use_zero and not is_moe_param:
|
||||
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
|
||||
with torch.no_grad():
|
||||
v = v.flatten()
|
||||
if padding_size > 0:
|
||||
v = torch.nn.functional.pad(v, [0, padding_size])
|
||||
slice_size = v.numel() // self.dp_size
|
||||
v = v.split(slice_size, dim=0)[self.dp_rank]
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, strict: bool = False):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,80 +1,92 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe import SparseMLP
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
|
||||
|
||||
|
||||
class MixtralSparseMLP:
|
||||
r"""
|
||||
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
|
||||
"""
|
||||
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.setup_ep()
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"FusedLayerNorm is not implemented as a physical class. "
|
||||
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
|
||||
)
|
||||
def setup_ep(self):
|
||||
_, moe_info = MOE_MANAGER.get_info(self.num_experts)
|
||||
ep_group = moe_info.ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
set_moe_tensor_info(p, moe_info)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module:
|
||||
r"""
|
||||
Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
|
||||
and optionally marking parameters for gradient aggregation.
|
||||
def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
|
||||
LazyInitContext.materialize(module)
|
||||
module.__class__ = EPMixtralSparseMoeBlock
|
||||
module.setup_ep()
|
||||
return module
|
||||
|
||||
Args:
|
||||
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
|
||||
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
Returns:
|
||||
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
Raises:
|
||||
AssertionError: If the provided module is not an instance of nn.LayerNorm.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
LazyInitContext.materialize(module)
|
||||
selected_experts = selected_experts.t().reshape(-1)
|
||||
selected_experts_idx = selected_experts.argsort()
|
||||
dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
|
||||
input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
# get the attributes of the module
|
||||
moe_kwargs = dict(
|
||||
num_experts=8,
|
||||
hidden_size=module.hidden_dim,
|
||||
intermediate_size=module.ffn_dim,
|
||||
router_top_k=module.top_k,
|
||||
router_norm=True,
|
||||
router_loss=False,
|
||||
# router_capacity_factor_train=
|
||||
# router_capacity_factor_eval=
|
||||
mlp_activation="silu",
|
||||
mlp_gated=True,
|
||||
# enable_load_balance=
|
||||
# load_balance_tolerance=
|
||||
# load_balance_beam_width=
|
||||
# load_balance_group_swap_factor=
|
||||
enable_kernel=enable_kernel,
|
||||
# enable_comm_overlap=
|
||||
# enable_hierarchical_comm=
|
||||
return_gate_logits=True,
|
||||
)
|
||||
dtype = module.gate.weight.dtype
|
||||
device = module.gate.weight.device
|
||||
sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
|
||||
|
||||
return sparse_mlp
|
||||
|
||||
|
||||
def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module:
|
||||
"""
|
||||
Reverse the replace layer operation
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): The object of layer to shard
|
||||
"""
|
||||
if isinstance(model, MixtralDecoderLayer):
|
||||
model.block_sparse_moe = MixtralSparseMLP.from_native_module(
|
||||
model.block_sparse_moe, enable_kernel=enable_kernel
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
# compute expert output
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
# no need to split
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
|
||||
output_states = expert.w2(output_states)
|
||||
else:
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0:
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
|
||||
split_states = expert.w2(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_experts_idx = torch.empty_like(selected_experts_idx)
|
||||
recover_experts_idx[selected_experts_idx] = torch.arange(
|
||||
selected_experts_idx.size(0), device=selected_experts_idx.device
|
||||
)
|
||||
else:
|
||||
for _, child in model.named_children():
|
||||
replace_moe_layer(child, enable_kernel)
|
||||
dispatch_states = dispatch_states[recover_experts_idx]
|
||||
k_hidden_states = dispatch_states.chunk(self.top_k)
|
||||
output_states = k_hidden_states[0] * routing_weights[:, 0, None]
|
||||
for i in range(1, self.top_k):
|
||||
output_states += k_hidden_states[i] * routing_weights[:, i, None]
|
||||
output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return output_states, router_logits
|
||||
|
|
|
@ -20,6 +20,8 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
|||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from .mixtral_layer import EPMixtralSparseMoeBlock
|
||||
|
||||
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
|
||||
|
||||
|
||||
|
@ -51,6 +53,18 @@ class MixtralPolicy(Policy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
|
||||
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="block_sparse_moe",
|
||||
target_module=EPMixtralSparseMoeBlock,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=MixtralDecoderLayer,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
|
|
@ -3,7 +3,6 @@ import os
|
|||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
|
@ -15,23 +14,6 @@ def move_to_cuda(batch, device):
|
|||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
|
||||
# pytorch ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
|
||||
# saved ckpt
|
||||
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
|
||||
# download
|
||||
else:
|
||||
ckpt_path = snapshot_download(ckpt_path)
|
||||
booster.load_model(model, ckpt_path)
|
||||
if optimizer is not None:
|
||||
optimizer.sync_moe_master_param()
|
||||
optimizer.update_master_params(model)
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
|
@ -90,7 +72,7 @@ def load_checkpoint(
|
|||
"""
|
||||
|
||||
# Update booster params states.
|
||||
load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
|
||||
booster.load_model(model, os.path.join(load_dir, "modeling"))
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
|
|
|
@ -2,10 +2,8 @@ import argparse
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from colossal_moe.utils import load_model
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
|
||||
|
@ -13,9 +11,6 @@ import colossalai
|
|||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -30,16 +25,10 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
default="ep",
|
||||
choices=["ep"],
|
||||
help="Parallel methos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./outputs",
|
||||
help="The path of your saved model after finetuning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
|
@ -71,60 +60,38 @@ def main():
|
|||
colossalai.launch_from_torch(config={}, seed=args.seed)
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
config = MixtralConfig.from_pretrained(args.model_name)
|
||||
ep_size = min(dist.get_world_size(), config.num_local_experts)
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": MixtralForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_layernorm_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": args.precision,
|
||||
"checkpoint_io": MixtralMoECheckpointIO,
|
||||
"zero_stage": 1,
|
||||
}
|
||||
mgr_dict = {}
|
||||
if args.plugin == "ep":
|
||||
dp_size = dist.get_world_size()
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
max_ep_size=dp_size,
|
||||
**mgr_dict,
|
||||
ep_size=ep_size,
|
||||
zero_stage=1,
|
||||
precision=args.precision,
|
||||
custom_policy=MixtralForCausalLMPolicy(),
|
||||
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
|
||||
enable_fused_normalization=args.use_layernorm_kernel,
|
||||
enable_jit_fused=args.use_kernel,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build mixtral model
|
||||
config = MixtralConfig.from_pretrained(args.model_name)
|
||||
config.num_local_experts = 1 # dont change this. it will not affect model
|
||||
with skip_init():
|
||||
model = MixtralForCausalLM(config)
|
||||
model.num_experts = 8
|
||||
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
|
||||
model = model.to(get_current_device())
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
|
||||
# Replace moe
|
||||
with skip_init():
|
||||
replace_moe_layer(model)
|
||||
model.eval()
|
||||
coordinator.print_on_master(f"Finish replace moe module")
|
||||
model = MixtralForCausalLM.from_pretrained(args.model_name)
|
||||
coordinator.print_on_master(f"Finish load model")
|
||||
|
||||
# Prepare tokenizer and dataloader
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, _, _, _, _ = booster.boost(model=model)
|
||||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# load ckpt
|
||||
load_model(args.model_name, model, booster)
|
||||
coordinator.print_on_master(f"Finish load ckpt")
|
||||
model.eval()
|
||||
|
||||
if coordinator.rank == 0:
|
||||
text = ["Hello my name is"]
|
||||
|
@ -132,10 +99,13 @@ def main():
|
|||
text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
|
||||
outputs = model.module.generate(**inputs, max_new_tokens=20)
|
||||
outputs = tokenizer.batch_decode(outputs)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model.module.generate(**inputs, max_new_tokens=20)
|
||||
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
print(f"[{coordinator.rank}] {outputs}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_layer import EPMixtralSparseMoeBlock
|
||||
from torch.testing import assert_close
|
||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import colossalai
|
||||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def check_mixtral_moe_layer():
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1
|
||||
)
|
||||
config = MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
orig_model = MixtralSparseMoeBlock(config).cuda()
|
||||
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
|
||||
orig_output, orig_logits = orig_model(x)
|
||||
model = deepcopy(orig_model)
|
||||
model = EPMixtralSparseMoeBlock.from_native_module(model)
|
||||
ep_output, ep_logits = model(x)
|
||||
assert_close(orig_logits, ep_logits)
|
||||
assert_close(orig_output, ep_output)
|
||||
orig_loss = orig_output.mean()
|
||||
orig_loss.backward()
|
||||
ep_loss = ep_output.mean()
|
||||
ep_loss.backward()
|
||||
assert_close(orig_loss, ep_loss)
|
||||
name_to_p = {n: p for n, p in orig_model.named_parameters()}
|
||||
for n, ep_p in model.named_parameters():
|
||||
p = name_to_p[n]
|
||||
if ep_p.grad is not None:
|
||||
assert_close(p.grad, ep_p.grad)
|
||||
|
||||
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch({}, rank, world_size, "localhost", port)
|
||||
check_mixtral_moe_layer()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("world_size", [2, 4])
|
||||
def test_mixtral_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_mixtral_moe_layer(2)
|
|
@ -1,185 +1,144 @@
|
|||
import os
|
||||
import shutil
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
from torch.optim import Adam
|
||||
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
def check_model_equal(model1, model2):
|
||||
assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
|
||||
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||
assert torch.equal(p1.half(), p2.half())
|
||||
|
||||
|
||||
def get_optimizer_snapshot(optim):
|
||||
state = {id(k): deepcopy(v) for k, v in optim.state.items()}
|
||||
param_groups = []
|
||||
for group in optim.param_groups:
|
||||
params = [id(p) for p in group["params"]]
|
||||
new_group = {"params": params}
|
||||
for k, v in group.items():
|
||||
if k != "params":
|
||||
new_group[k] = v
|
||||
param_groups.append(new_group)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": input_ids,
|
||||
"state": state,
|
||||
"param_groups": param_groups,
|
||||
}
|
||||
|
||||
|
||||
def run_fwd_bwd(
|
||||
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
|
||||
):
|
||||
model.train()
|
||||
if pipeline:
|
||||
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
|
||||
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
|
||||
y = booster.execute_pipeline(
|
||||
train_dataloader_iter,
|
||||
model,
|
||||
lambda x, y: x.loss,
|
||||
optimizer,
|
||||
return_loss=True,
|
||||
return_outputs=True,
|
||||
)
|
||||
# Backward and optimize
|
||||
if is_pp_last_stage:
|
||||
loss = y["loss"]
|
||||
else:
|
||||
if criterion:
|
||||
y = model(data).logits
|
||||
loss = criterion(y)
|
||||
else:
|
||||
loss = model(data, label)
|
||||
loss = loss.float()
|
||||
|
||||
if optimizer is not None:
|
||||
optimizer.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
return y
|
||||
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
|
||||
# check param_groups
|
||||
assert len(snapshot1["param_groups"]) == len(snapshot2["param_groups"])
|
||||
for group1, group2 in zip(snapshot1["param_groups"], snapshot2["param_groups"]):
|
||||
assert set(group1.keys()) == set(group2.keys())
|
||||
for k in group1.keys():
|
||||
assert group1[k] == group2[k]
|
||||
# check state
|
||||
assert set(snapshot1["state"].keys()) == set(
|
||||
snapshot2["state"].keys()
|
||||
), f"{snapshot1['state'].keys()}, {snapshot2['state'].keys()}"
|
||||
for pid in snapshot1["state"].keys():
|
||||
state1, state2 = snapshot1["state"][pid], snapshot2["state"][pid]
|
||||
assert set(state1.keys()) == set(state2.keys())
|
||||
for k in state1.keys():
|
||||
if isinstance(state1[k], torch.Tensor):
|
||||
assert torch.equal(state1[k], state2[k]), f"{k}, {state1[k]}, {state2[k]}"
|
||||
else:
|
||||
assert state1[k] == state2[k]
|
||||
|
||||
|
||||
def get_config():
|
||||
def check_mixtral_moe_layer():
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
config = MixtralConfig(
|
||||
vocab_size=300,
|
||||
hidden_size=32,
|
||||
intermediate_size=16,
|
||||
num_hidden_layers=2,
|
||||
dropout_rate=0.0,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def get_model(parallel):
|
||||
config = get_config()
|
||||
model = MixtralForCausalLM(config).to(torch.bfloat16)
|
||||
replace_moe_layer(model)
|
||||
optim = torch.optim.Adam(model.parameters())
|
||||
args = dict(
|
||||
precision="bf16",
|
||||
torch.manual_seed(0)
|
||||
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
|
||||
orig_model = MixtralForCausalLM(config).cuda()
|
||||
model = deepcopy(orig_model)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
zero_stage=1,
|
||||
pp_size=2,
|
||||
ep_size=2,
|
||||
custom_policy=MixtralForCausalLMPolicy(),
|
||||
checkpoint_io=MixtralMoECheckpointIO,
|
||||
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
|
||||
microbatch_size=1,
|
||||
zero_stage=1,
|
||||
)
|
||||
if parallel == "ep":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=1,
|
||||
**args,
|
||||
)
|
||||
elif parallel == "hybrid":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
pp_size=2,
|
||||
microbatch_size=1,
|
||||
**args,
|
||||
)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
|
||||
return model, booster, optim
|
||||
|
||||
|
||||
def _test_moe_checkpoint(parallel):
|
||||
if dist.get_rank() == 0:
|
||||
if os.path.exists("./tmp_ckpt1"):
|
||||
shutil.rmtree("./tmp_ckpt1")
|
||||
if os.path.exists("./tmp_ckpt2"):
|
||||
shutil.rmtree("./tmp_ckpt2")
|
||||
dist.barrier()
|
||||
|
||||
if parallel == None:
|
||||
MOE_MANAGER.setup(
|
||||
parallel=None,
|
||||
)
|
||||
elif parallel == "ep":
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
)
|
||||
elif parallel == "hybrid":
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=1,
|
||||
fixed_ep_size=2,
|
||||
fixed_pp_size=2,
|
||||
)
|
||||
model1, booster1, optim1 = get_model(parallel)
|
||||
model2, booster2, optim2 = get_model(parallel)
|
||||
# param ckpt
|
||||
# check not equal
|
||||
try:
|
||||
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
|
||||
raise AssertionError("state_dict should not be equal")
|
||||
except:
|
||||
pass
|
||||
# shard
|
||||
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
|
||||
booster2.load_model(model2, "./tmp_ckpt1")
|
||||
# check
|
||||
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
|
||||
|
||||
# optim ckpt
|
||||
criterion = lambda x: x.mean()
|
||||
data = torch.randint(0, 4, (2, 4)).cuda()
|
||||
label = torch.randint(0, 4, (2,)).cuda()
|
||||
if parallel == "hybrid":
|
||||
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
|
||||
else:
|
||||
kwargs = {}
|
||||
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
|
||||
optim1.step()
|
||||
optim1.zero_grad()
|
||||
# shard
|
||||
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
|
||||
dist.barrier()
|
||||
booster2.load_optimizer(optim2, "./tmp_ckpt2")
|
||||
# check
|
||||
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
shutil.rmtree("./tmp_ckpt1")
|
||||
shutil.rmtree("./tmp_ckpt2")
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port, parallel):
|
||||
colossalai.launch(
|
||||
config=dict(),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host="localhost",
|
||||
port=port,
|
||||
backend="nccl",
|
||||
model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
|
||||
# initialize grads
|
||||
data_iter = iter(
|
||||
[{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
|
||||
)
|
||||
_test_moe_checkpoint(parallel)
|
||||
booster.execute_pipeline(
|
||||
data_iter,
|
||||
model,
|
||||
lambda outputs, inputs: outputs.loss,
|
||||
optimizer,
|
||||
)
|
||||
|
||||
# check save model
|
||||
booster.save_model(model, "mixtral_model", shard=True)
|
||||
dist.barrier()
|
||||
if dist.get_rank() == 0:
|
||||
saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
|
||||
check_model_equal(orig_model, saved_model)
|
||||
saved_model.save_pretrained("mixtral_hf_model")
|
||||
dist.barrier()
|
||||
|
||||
# check load model
|
||||
new_model = MixtralForCausalLM(config).cuda()
|
||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
||||
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
|
||||
booster.load_model(new_model, "mixtral_hf_model")
|
||||
check_model_equal(model, new_model)
|
||||
|
||||
# check save optimizer
|
||||
optimizer.step()
|
||||
snapshot = get_optimizer_snapshot(optimizer.unwrap())
|
||||
booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
|
||||
dist.barrier()
|
||||
# reset optimizer state
|
||||
for state in optimizer.unwrap().state.values():
|
||||
for v in state.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v.zero_()
|
||||
booster.load_optimizer(optimizer, "mixtral_optim")
|
||||
loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
|
||||
check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
|
||||
|
||||
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch({}, rank, world_size, "localhost", port)
|
||||
check_mixtral_moe_layer()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [4])
|
||||
@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_moe_checkpoint(world_size, parallel):
|
||||
spawn(_run_dist, world_size, parallel=parallel)
|
||||
def test_mixtral_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_checkpoint(world_size=4, parallel="hybrid")
|
||||
test_mixtral_moe_layer(4)
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
import copy
|
||||
|
||||
import torch
|
||||
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_local_experts = num_local_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
|
||||
def test_moe_layer():
|
||||
config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
|
||||
mistral_moe = MixtralSparseMoeBlock(config).cuda()
|
||||
colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()
|
||||
|
||||
data = torch.randn(2, 8, 4).cuda()
|
||||
mistral_output = mistral_moe(data)[0]
|
||||
colossal_output = colossal_moe(data)[0]
|
||||
assert torch.allclose(
|
||||
mistral_output, colossal_output
|
||||
), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_moe_layer()
|
|
@ -2,22 +2,18 @@ import argparse
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoEHybridParallelCheckpointIO
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
|
||||
from colossal_moe.utils import load_checkpoint, move_to_cuda, save_checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
from transformers.models.mixtral import MixtralForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.moe import MOE_MANAGER, apply_load_balance
|
||||
from colossalai.moe.layers import apply_load_balance
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -153,45 +149,27 @@ def main():
|
|||
coordinator = DistCoordinator()
|
||||
|
||||
# Set plugin
|
||||
booster_kwargs = {}
|
||||
hybrid_dict = {
|
||||
"tp_size": 1,
|
||||
"custom_policy": MixtralForCausalLMPolicy(),
|
||||
"enable_fused_normalization": args.use_layernorm_kernel,
|
||||
"enable_jit_fused": args.use_kernel,
|
||||
"precision": args.precision,
|
||||
"zero_stage": args.zero_stage,
|
||||
"checkpoint_io": MixtralMoECheckpointIO,
|
||||
}
|
||||
mgr_dict = {}
|
||||
if args.plugin == "hybrid":
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
tp_size=1,
|
||||
pp_size=args.pp_size,
|
||||
ep_size=args.ep_size,
|
||||
microbatch_size=args.microbatch_size,
|
||||
**hybrid_dict,
|
||||
)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=args.dp_size,
|
||||
fixed_ep_size=args.ep_size,
|
||||
fixed_pp_size=args.pp_size,
|
||||
**mgr_dict,
|
||||
custom_policy=MixtralForCausalLMPolicy(),
|
||||
enable_fused_normalization=args.use_layernorm_kernel,
|
||||
enable_jit_fused=args.use_kernel,
|
||||
precision=args.precision,
|
||||
zero_stage=args.zero_stage,
|
||||
checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid plugin {args.plugin}")
|
||||
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
|
||||
|
||||
# Build Mixtral model
|
||||
config = MixtralConfig.from_pretrained(args.model_name)
|
||||
config.use_cache = False
|
||||
config.num_local_experts = 1
|
||||
model = MixtralForCausalLM(config)
|
||||
model.num_experts = 8
|
||||
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
|
||||
model = model.to(get_current_device())
|
||||
replace_moe_layer(model, enable_kernel=args.use_kernel)
|
||||
coordinator.print_on_master(f"Finish init model with config:\n{config}")
|
||||
model = MixtralForCausalLM.from_pretrained(args.model_name)
|
||||
coordinator.print_on_master(f"Finish init model")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
@ -224,7 +202,7 @@ def main():
|
|||
)
|
||||
|
||||
# Set booster
|
||||
booster = Booster(plugin=plugin, **booster_kwargs)
|
||||
booster = Booster(plugin=plugin)
|
||||
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
|
@ -236,10 +214,7 @@ def main():
|
|||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# Load ckpt
|
||||
if args.load_checkpoint is None:
|
||||
load_model(args.model_name, model, booster, optimizer)
|
||||
coordinator.print_on_master(f"Finish load checkpoint")
|
||||
else:
|
||||
if args.load_checkpoint is not None:
|
||||
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
|
||||
coordinator.print_on_master(f"Finish load optimizer")
|
||||
|
||||
|
@ -286,13 +261,13 @@ def main():
|
|||
optimizer.zero_grad()
|
||||
|
||||
# Apply load balance
|
||||
if (
|
||||
args.load_balance
|
||||
and args.load_balance_interval > 0
|
||||
and (step + 1) % args.load_balance_interval == 0
|
||||
):
|
||||
coordinator.print_on_master(f"Apply load balance")
|
||||
apply_load_balance(model, optimizer)
|
||||
# if (
|
||||
# args.load_balance
|
||||
# and args.load_balance_interval > 0
|
||||
# and (step + 1) % args.load_balance_interval == 0
|
||||
# ):
|
||||
# coordinator.print_on_master(f"Apply load balance")
|
||||
# apply_load_balance(model, optimizer)
|
||||
# save ckeckpoint
|
||||
if (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
|
||||
|
|
|
@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
|
|||
)
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.moe import MoECheckpintIO
|
||||
from colossalai.moe import MOE_MANAGER, MoECheckpintIO
|
||||
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig
|
||||
|
@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
self,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
ep_size: int,
|
||||
extra_dp_size: int = 1,
|
||||
precision: str = "fp16",
|
||||
zero_stage: int = 0,
|
||||
|
@ -189,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
|
||||
if enable_sequence_parallelism:
|
||||
assert tp_size > 1, "Sequence parallelism must be enabled when using tensor parallelism"
|
||||
|
||||
assert (
|
||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
assert (
|
||||
dist.get_world_size() % (tp_size * pp_size * ep_size) == 0
|
||||
), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size} * ep_size {ep_size}"
|
||||
self.real_dp_size = dist.get_world_size() // (tp_size * pp_size * ep_size)
|
||||
MOE_MANAGER.setup(
|
||||
parallel="EP",
|
||||
mode="fixed",
|
||||
fixed_dp_size=self.real_dp_size,
|
||||
fixed_ep_size=ep_size,
|
||||
fixed_pp_size=pp_size,
|
||||
use_ep_inside=use_ep_inside,
|
||||
)
|
||||
self.tp_size = tp_size
|
||||
self.pp_size = pp_size
|
||||
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
|
||||
self.ep_size = ep_size
|
||||
self.moe_info = MOE_MANAGER.get_info(0)[1]
|
||||
self.precision = precision
|
||||
self.zero_stage = zero_stage
|
||||
self.cpu_offload = cpu_offload
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
|
||||
from colossalai.interface import ModelWrapper
|
||||
|
||||
from .utils import has_index_file
|
||||
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file
|
||||
|
||||
__all__ = ["CheckpointIO"]
|
||||
|
||||
|
@ -90,7 +90,15 @@ class CheckpointIO(ABC):
|
|||
if index_file_exists:
|
||||
self.load_sharded_model(model, index_file_path, strict)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
path = Path(checkpoint, SAFE_WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
else:
|
||||
path = Path(checkpoint, WEIGHTS_NAME)
|
||||
if path.is_file():
|
||||
self.load_unsharded_model(model, str(path), strict)
|
||||
else:
|
||||
self.load_unsharded_model(model, checkpoint, strict)
|
||||
|
||||
return origin_model
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -329,3 +329,68 @@ class MoeOutGradScaler(torch.autograd.Function):
|
|||
if ctx.ep_size != 1:
|
||||
grad = grad / ctx.ep_size
|
||||
return grad, None
|
||||
|
||||
|
||||
def _all_to_all(
|
||||
inputs: torch.Tensor,
|
||||
input_split_sizes: Optional[List[int]] = None,
|
||||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
async_op: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
outputs_shape = list(inputs.shape)
|
||||
if output_split_sizes is not None:
|
||||
outputs_shape[0] = sum(output_split_sizes)
|
||||
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
|
||||
inputs = inputs.contiguous()
|
||||
outputs = outputs.contiguous()
|
||||
handle = dist.all_to_all_single(
|
||||
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
|
||||
)
|
||||
return outputs, handle
|
||||
|
||||
|
||||
class AllToAllUneven(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
inputs,
|
||||
input_split_sizes=None,
|
||||
output_split_sizes=None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
outputs: Tensor
|
||||
handle: Optional[Work], if overlap is True
|
||||
"""
|
||||
ctx.input_split_sizes = input_split_sizes
|
||||
ctx.output_split_sizes = output_split_sizes
|
||||
ctx.group = group
|
||||
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, *grad_outputs):
|
||||
return (
|
||||
_all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group, False)[0],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def all_to_all_uneven(
|
||||
inputs: torch.Tensor,
|
||||
input_split_sizes: Optional[List[int]] = None,
|
||||
output_split_sizes: Optional[List[int]] = None,
|
||||
group=None,
|
||||
overlap: bool = False,
|
||||
):
|
||||
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
|
||||
|
|
|
@ -26,3 +26,5 @@ class MoeParallelInfo:
|
|||
self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
|
||||
self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
|
||||
self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
|
||||
self.ep_rank = self.pg.coordinate(self.ep_axis)
|
||||
self.dp_rank = self.pg.coordinate(self.dp_axis)
|
||||
|
|
|
@ -666,10 +666,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
|
||||
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
|
||||
|
||||
def sync_moe_master_param(self):
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
|
||||
|
||||
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
|
||||
r"""
|
||||
Compute and return the gradient norm for gradient clipping.
|
||||
|
@ -915,9 +911,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
|
|||
master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank])
|
||||
else:
|
||||
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
|
||||
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
|
||||
master_moe_param.copy_(working_moe_param)
|
||||
|
||||
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
|
||||
return self._param_store.working_to_master_param
|
||||
|
||||
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
|
||||
return self._param_store.master_to_working_param
|
||||
return {**self._param_store.master_to_working_param, **self.moe_master_to_working_map}
|
||||
|
|
Loading…
Reference in New Issue