mirror of https://github.com/hpcaitech/ColossalAI
[moe] init mixtral impl
parent c53ddda88f
commit 7d8e0338a4
Binary file not shown.
colossal_moe/models/mixtral_checkpoint.py
@@ -0,0 +1,205 @@
import logging
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor


class MixtralMoECheckpointIO(MoECheckpintIO):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
        """
        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
        """
        model_param_dict = dict(model.named_parameters())
        for name, param in list(state_dict.items()):
            if ".gate.weight" in name:
                new_name = "module." + name.replace(".gate.weight", ".gate_weight")
                state_dict[new_name] = state_dict.pop(name)
            elif ".experts." in name:
                # if this is a MoE tensor:
                # in our MoE module, all experts are concatenated into one tensor,
                # but Mixtral keeps its experts as separate tensors,
                # so we insert each loaded expert into its slot of the concatenated tensor

                # get model param
                str_idx = name.index(".experts.")
                expert_idx = int(name.split(".")[-3])
                if ".w1." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
                elif ".w2." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wo")
                elif ".w3." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
                model_param_name = "module." + model_param_name
                # skip for pipeline
                if model_param_name not in model_param_dict:
                    continue
                model_param = model_param_dict[model_param_name]
                assert is_moe_tensor(model_param)
                # get expert range
                ep_rank = get_ep_rank(model_param)
                ep_size = get_ep_size(model_param)
                expert_num = 8 // ep_size
                expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
                # insert new param
                if expert_idx in expert_range:
                    new_param = model_param
                    new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
                    state_dict[model_param_name] = new_param
                state_dict.pop(name)
            else:
                new_name = "module." + name
                state_dict[new_name] = state_dict.pop(name)

        dist.barrier()
        return state_dict

    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
        """
        Load sharded model with the given path to index file of checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (str): Path to the index file of checkpointing folder.
            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                This argument should be manually set to False since params on the same device might be stored in different files.
        """

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        strict = False

        # Load params & buffers to model.
        # Keep a record of loaded files so that a file will not be repeatedly loaded.
        loaded_file = set()

        def _load(name: str):
            if name not in weight_map:
                raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
            state_dict = self.pre_load_model(model, state_dict)
            missing_keys = []

            load_state_dict_into_model(
                model,
                state_dict,
                missing_keys=missing_keys,
                strict=strict,
                load_sub_module=True,
            )
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            name = name.replace("module.", "")
            name = name.replace(".gate_weight", ".gate.weight")
            if ".experts.wi_gate" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                    _load(new_name)
            elif ".experts.wi_up" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                    _load(new_name)
            elif ".experts.wo" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                    _load(new_name)
            else:
                _load(name)

        if self.verbose:
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

    @torch.no_grad()
    def pre_save_model(self, model: nn.Module) -> dict:
        torch.cuda.empty_cache()
        state_dict = model.state_dict()
        for name, param in list(model.named_parameters()):
            if ".gate_weight" in name:
                new_name = name.replace(".gate_weight", ".gate.weight")
                state_dict[new_name] = state_dict.pop(name).cpu()
            elif ".experts." in name:
                ep_group = get_ep_group(param)
                ep_rank = get_ep_rank(param)
                ep_size = get_ep_size(param)
                dp_rank = get_dp_rank(param)

                if dp_rank == 0:
                    param = param.data.cuda()
                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
                    # gather param from every ep rank
                    dist.all_gather(all_param, param, group=ep_group)
                    if ep_rank == 0:
                        all_param = torch.cat(all_param, dim=0)
                        assert all_param.shape[0] == 8
                        for i in range(8):
                            if ".wi_gate" in name:
                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                            elif ".wi_up" in name:
                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                            elif ".wo" in name:
                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                            new_name = new_name.replace("module.", "")
                            new_param = all_param[i].transpose(-1, -2)
                            state_dict[new_name] = new_param.cpu()
                state_dict.pop(name)
            else:
                state_dict[name] = param.cpu()

        for name, param in list(state_dict.items()):
            new_name = name.replace("module.", "")
            state_dict[new_name] = state_dict.pop(name)

        torch.cuda.empty_cache()
        if self.pp_size > 1:
            if self.dp_rank == 0:
                # gather state_dict from every pp rank;
                # because the ckpt is large, we split it into up to 30 parts
                # and gather them one by one
                new_state_dict = {}
                state_dict_keys = list(state_dict.keys())
                gap_key_num = min(30, len(state_dict_keys))
                gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num
                for i in range(gap_key_num):
                    cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
                    cur_state_dict = {}
                    for k in cur_keys:
                        cur_state_dict[k] = state_dict[k]
                    out = [None for _ in range(self.pp_size)]
                    dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
                    if self.pp_rank == 0:
                        for o in out:
                            for k, v in o.items():
                                new_state_dict[k] = v.cpu()
                state_dict = new_state_dict
        dist.barrier()
        return state_dict
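Note (not part of the original diff): a minimal, self-contained sketch of the expert-range arithmetic that `pre_load_model` and `pre_save_model` rely on, assuming the fixed total of 8 Mixtral experts used above; the helper name is illustrative only.

```python
# Illustration of the expert-range bookkeeping in pre_load_model (assumes 8 experts total).
NUM_EXPERTS = 8


def local_expert_slot(expert_idx: int, ep_rank: int, ep_size: int):
    """Return the local slot for a global expert index, or None if another EP rank owns it."""
    expert_num = NUM_EXPERTS // ep_size  # experts held per expert-parallel rank
    start = ep_rank * expert_num
    if start <= expert_idx < start + expert_num:  # same check as `expert_idx in expert_range`
        return expert_idx - start  # position inside the concatenated expert tensor
    return None


# Example: with ep_size=2, rank 1 holds global experts 4..7 in local slots 0..3.
assert local_expert_slot(5, ep_rank=1, ep_size=2) == 1
assert local_expert_slot(2, ep_rank=1, ep_size=2) is None
```

Each expert-parallel rank therefore owns a contiguous block of experts, which is why a loaded HuggingFace expert tensor can be written into slot `expert_idx - ep_rank * expert_num` of the concatenated parameter.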
colossal_moe/models/mixtral_layer.py
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock

from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP


class MixtralSparseMLP:
    r"""
    A conversion helper that turns a HuggingFace MixtralSparseMoeBlock into a ColossalAI SparseMLP.
    It is meant to be used only through the from_native_module interface.
    """

    def __init__(self) -> None:
        raise NotImplementedError(
            "MixtralSparseMLP is not implemented as a physical class. "
            "It is meant to be used only with the from_native_module interface to convert a native MixtralSparseMoeBlock into a ColossalAI SparseMLP."
        )

    @staticmethod
    def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool = False) -> nn.Module:
        r"""
        Convert a native HuggingFace MixtralSparseMoeBlock into a ColossalAI SparseMLP with the same
        hidden size, intermediate size, and top-k routing configuration.

        Args:
            module (MixtralSparseMoeBlock): The native Mixtral MoE block to be converted.
            enable_kernel (bool): Whether to enable the optimized MoE kernels.

        Returns:
            nn.Module: The converted SparseMLP module, placed on the same dtype and device as the gate.
        """
        with torch.no_grad():
            LazyInitContext.materialize(module)

        # get the attributes of the module
        moe_kwargs = dict(
            num_experts=8,
            hidden_size=module.hidden_dim,
            intermediate_size=module.ffn_dim,
            router_top_k=module.top_k,
            router_norm=True,
            router_loss=False,
            # router_capacity_factor_train=
            # router_capacity_factor_eval=
            mlp_activation="silu",
            mlp_gated=True,
            # enable_load_balance=
            # load_balance_tolerance=
            # load_balance_beam_width=
            # load_balance_group_swap_factor=
            enable_kernel=enable_kernel,
            # enable_comm_overlap=
            # enable_hierarchical_comm=
            return_gate_logits=True,
        )
        dtype = module.gate.weight.dtype
        device = module.gate.weight.device
        sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)

        return sparse_mlp


def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module:
    """
    Recursively replace every MixtralSparseMoeBlock in the model with a ColossalAI SparseMLP.

    Args:
        model (torch.nn.Module): The module (or model) whose MoE layers should be replaced in place.
        enable_kernel (bool): Whether to enable the optimized MoE kernels.
    """
    if isinstance(model, MixtralDecoderLayer):
        model.block_sparse_moe = MixtralSparseMLP.from_native_module(
            model.block_sparse_moe, enable_kernel=enable_kernel
        )
    else:
        for _, child in model.named_children():
            replace_moe_layer(child, enable_kernel)
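Note (not part of the original diff): a toy, runnable illustration of the walk-and-swap pattern that `replace_moe_layer` applies. The real function keys on `MixtralDecoderLayer` and replaces its `block_sparse_moe` attribute via `MixtralSparseMLP.from_native_module`; the stand-in below simply swaps `nn.ReLU` for `nn.GELU` to show the recursion.

```python
# Toy illustration of the in-place, recursive swap pattern used by replace_moe_layer above.
import torch.nn as nn


def replace_activation(model: nn.Module) -> None:
    for name, child in model.named_children():
        if isinstance(child, nn.ReLU):  # stand-in for the MixtralDecoderLayer check
            setattr(model, name, nn.GELU())  # stand-in for MixtralSparseMLP.from_native_module(...)
        else:
            replace_activation(child)


net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
replace_activation(net)
print(net)  # the nested ReLU is now a GELU
```

Because the swap is done by attribute assignment on the parent module, the conversion happens in place and no new top-level model object is created.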
colossal_moe/models/mixtral_policy.py
@@ -0,0 +1,543 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer,
    MixtralForCausalLM,
    MixtralModel,
    MoeCausalLMOutputWithPast,
    _prepare_4d_causal_attention_mask,
    load_balancing_loss_func,
)
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig

__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]


class MixtralPolicy(Policy):
    def config_sanity_check(self):
        pass

    def preprocess(self):
        if self.shard_config.enable_tensor_parallelism:
            # Resize embedding
            vocab_size = self.model.config.vocab_size
            world_size = self.shard_config.tensor_parallel_size

            if vocab_size % world_size != 0:
                new_vocab_size = vocab_size + world_size - vocab_size % world_size
                self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {}

        if self.shard_config.enable_sequence_parallelism:
            self.shard_config.enable_sequence_parallelism = False
            raise NotImplementedError(
                "Mixtral doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
            )

        if self.shard_config.enable_tensor_parallelism:
            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="input_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="post_attention_layernorm",
                        target_module=FusedRMSNorm,
                    ),
                ],
                policy=policy,
                target_key=MixtralDecoderLayer,
            )

            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="norm",
                    target_module=FusedRMSNorm,
                ),
                policy=policy,
                target_key=MixtralModel,
            )

        if self.shard_config.enable_flash_attention:
            raise NotImplementedError("Flash attention has already been replaced in mixtral.")

        return policy

    def postprocess(self):
        return self.model

    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replace the original huggingface forward method
        with a customized forward method, and add this change to the policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "MixtralModel":
                module = self.model
            else:
                module = self.model.model

            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == "MixtralModel":
            module = self.model
        else:
            module = self.model.model
        stage_manager = self.pipeline_stage_manager

        held_layers = []
        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
        if stage_manager.is_first_stage():
            held_layers.append(module.embed_tokens)
        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
        held_layers.extend(module.layers[start_idx:end_idx])
        if stage_manager.is_last_stage():
            held_layers.append(module.norm)

        return held_layers


class MixtralModelPolicy(MixtralPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=MixtralModel,
                new_forward=MixtralPipelineForwards.mixtral_model_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        held_layers = super().get_held_layers()
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No shared params in the Mixtral model"""
        return []


class MixtralForCausalLMPolicy(MixtralPolicy):
    def module_policy(self):
        policy = super().module_policy()

        if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal lm
            new_item = {
                MixtralForCausalLM: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="lm_head",
                            target_module=Linear1D_Col,
                            kwargs=dict(gather_output=True),
                        )
                    ]
                )
            }
            policy.update(new_item)

        if self.pipeline_stage_manager:
            # set None as default
            self.set_pipeline_forward(
                model_cls=MixtralForCausalLM,
                new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
                policy=policy,
            )

        return policy

    def get_held_layers(self) -> List[Module]:
        """Get pipeline layers for current stage."""
        stage_manager = self.pipeline_stage_manager
        held_layers = super().get_held_layers()
        if stage_manager.is_last_stage():
            held_layers.append(self.model.lm_head)
        return held_layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        llama_model = self.model.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if (
                id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
                and self.pipeline_stage_manager.num_stages > 1
            ):
                # tie weights
                return [
                    {
                        0: llama_model.embed_tokens.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
                    }
                ]
        return []


class MixtralPipelineForwards:
    """
    This class serves as a micro library for forward function substitution of Mixtral models
    under pipeline setting.
    """

    @staticmethod
    def mixtral_model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if stage_manager.is_first_stage():
            # retrieve input_ids and inputs_embeds
            if input_ids is not None and inputs_embeds is not None:
                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
            elif input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
            else:
                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if inputs_embeds is None:
                inputs_embeds = self.embed_tokens(input_ids)
            hidden_states = inputs_embeds
        else:
            input_shape = hidden_states.shape[:-1]
            batch_size, seq_length = input_shape
            device = hidden_states.device

        seq_length_with_past = seq_length
        past_key_values_length = 0

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False
        if use_cache:
            logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
            use_cache = False

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # embed positions; for the first stage, hidden_states is the input embeddings,
        # for the other stages, hidden_states is the output of the previous stage
        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                hidden_states,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        start_idx, end_idx = stage_index[0], stage_index[1]
        for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    output_router_logits,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if output_router_logits:
                all_router_logits += (layer_outputs[-1],)

        if stage_manager.is_last_stage():
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = next_decoder_cache if use_cache else None

        if output_router_logits and past_router_logits is not None:
            all_router_logits = past_router_logits + all_router_logits
        if stage_manager.is_last_stage():
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        # always return dict for intermediate stages
        return {
            "hidden_states": hidden_states,
            "past_router_logits": all_router_logits,
        }

    @staticmethod
    def mixtral_for_causal_lm_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        past_router_logits: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MixtralForCausalLM

        >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        logger = logging.get_logger(__name__)
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
        if output_attentions:
            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
            output_attentions = False
        if output_hidden_states:
            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
            output_hidden_states = False

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = MixtralPipelineForwards.mixtral_model_forward(
            self.model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            past_router_logits=past_router_logits,
        )
        past_key_values = None

        if stage_manager.is_last_stage():
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits.float()

            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

            aux_loss = None
            if output_router_logits:
                aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
                if labels is not None:
                    loss += self.router_aux_loss_coef * aux_loss

            if not return_dict:
                output = (logits,) + outputs[1:]
                if output_router_logits:
                    output = (aux_loss,) + output
                return (loss,) + output if loss is not None else output

            return MoeCausalLMOutputWithPast(
                loss=loss,
                aux_loss=aux_loss,
                logits=logits,
                past_key_values=None,
                hidden_states=outputs[0],
                attentions=None,
                router_logits=outputs[-1],
            )
        else:
            out = {}
            hidden_states = outputs.get("hidden_states")
            out["hidden_states"] = hidden_states
            if output_router_logits:
                out["past_router_logits"] = outputs["past_router_logits"]
            return out
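Note (not part of the original diff): `distribute_layers` and `get_stage_index` come from the ColossalAI `Policy` base class and are not shown in this commit; the snippet below is only an assumed, simplified illustration of the idea they implement — contiguous chunks of decoder layers per pipeline stage, consumed as `self.layers[start_idx:end_idx]` in the pipeline forwards above.

```python
# Illustrative only: even split of `num_layers` decoder layers across `num_stages` pipeline stages.
def split_layers(num_layers: int, num_stages: int):
    base, rem = divmod(num_layers, num_stages)
    sizes = [base + (1 if s < rem else 0) for s in range(num_stages)]
    bounds, start = [], 0
    for size in sizes:
        bounds.append((start, start + size))  # [start_idx, end_idx) consumed by the stage forward
        start += size
    return bounds


# e.g. 32 Mixtral layers over 4 stages -> [(0, 8), (8, 16), (16, 24), (24, 32)]
print(split_layers(32, 4))
```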
colossal_moe/utils.py
@@ -0,0 +1,102 @@
import json
import os
from typing import Any, Dict, Tuple, Union

import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator


def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


@torch.no_grad()
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
    # pytorch ckpt
    if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
        ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
    # saved ckpt
    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    # download
    else:
        ckpt_path = snapshot_download(ckpt_path)
    booster.load_model(model, ckpt_path)
    if optimizer is not None:
        optimizer.sync_moe_master_param()
        optimizer.update_master_params(model)


def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)


def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))


def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
    """

    # Update booster params states.
    load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return (
        running_states["epoch"],
        running_states["step"],
        running_states["sample_start_index"],
    )
infer.py
@@ -0,0 +1,138 @@
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_model
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["ep"],
        help="Parallel method.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": MixtralForCausalLMPolicy(),
        "enable_fused_normalization": args.use_layernorm_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "checkpoint_io": MixtralMoECheckpointIO,
        "zero_stage": 1,
    }
    mgr_dict = {}
    if args.plugin == "ep":
        dp_size = dist.get_world_size()
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            max_ep_size=dp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build mixtral model
    config = MixtralConfig.from_pretrained(args.model_name)
    config.num_local_experts = 1  # don't change this; it will not affect the model
    with skip_init():
        model = MixtralForCausalLM(config)
    model.num_experts = 8
    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
    model = model.to(get_current_device())
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Replace moe
    with skip_init():
        replace_moe_layer(model)
    model.eval()
    coordinator.print_on_master(f"Finish replace moe module")

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, _, _, _, _ = booster.boost(model=model)
    coordinator.print_on_master(f"Finish init booster")

    # load ckpt
    load_model(args.model_name, model, booster)
    coordinator.print_on_master(f"Finish load ckpt")

    text = ["Hello my name is", "1+1=?"]
    tokenizer.pad_token = tokenizer.unk_token
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
    outputs = model.module.generate(**inputs, max_new_tokens=20)
    outputs = tokenizer.batch_decode(outputs)[0]
    print(outputs)


if __name__ == "__main__":
    main()
@@ -0,0 +1,7 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"

# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
    --model_name $MODEL \
    --plugin "ep" \
requirements.txt
@@ -0,0 +1,5 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets
setup.py
@@ -0,0 +1,43 @@
from setuptools import find_packages, setup


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


def fetch_readme():
    with open("README.md", encoding="utf-8") as f:
        return f.read()


def fetch_version():
    with open("version.txt", "r") as f:
        return f.read().strip()


setup(
    name="colossal_moe",
    version=fetch_version(),
    packages=find_packages(
        exclude=(
            "tests",
            "benchmarks",
            "*.egg-info",
        )
    ),
    description="Colossal-AI MoE",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)
@@ -0,0 +1,185 @@
import os
import shutil

import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device


def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
    }


def run_fwd_bwd(
    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
    model.train()
    if pipeline:
        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
        is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
        y = booster.execute_pipeline(
            train_dataloader_iter,
            model,
            lambda x, y: x.loss,
            optimizer,
            return_loss=True,
            return_outputs=True,
        )
        # Backward and optimize
        if is_pp_last_stage:
            loss = y["loss"]
    else:
        if criterion:
            y = model(data).logits
            loss = criterion(y)
        else:
            loss = model(data, label)
        loss = loss.float()

        if optimizer is not None:
            optimizer.backward(loss)
        else:
            loss.backward()
    return y


def get_config():
    config = MixtralConfig(
        vocab_size=300,
        hidden_size=32,
        intermediate_size=16,
        num_hidden_layers=2,
        dropout_rate=0.0,
    )
    return config


def get_model(parallel):
    config = get_config()
    model = MixtralForCausalLM(config).to(torch.bfloat16)
    replace_moe_layer(model)
    optim = torch.optim.Adam(model.parameters())
    args = dict(
        precision="bf16",
        tp_size=1,
        zero_stage=1,
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoECheckpointIO,
    )
    if parallel == "ep":
        plugin = MoeHybridParallelPlugin(
            pp_size=1,
            **args,
        )
    elif parallel == "hybrid":
        plugin = MoeHybridParallelPlugin(
            pp_size=2,
            microbatch_size=1,
            **args,
        )
    booster = Booster(plugin=plugin)
    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
    return model, booster, optim


def _test_moe_checkpoint(parallel):
    if dist.get_rank() == 0:
        if os.path.exists("./tmp_ckpt1"):
            shutil.rmtree("./tmp_ckpt1")
        if os.path.exists("./tmp_ckpt2"):
            shutil.rmtree("./tmp_ckpt2")
    dist.barrier()

    if parallel == None:
        MOE_MANAGER.setup(
            parallel=None,
        )
    elif parallel == "ep":
        MOE_MANAGER.setup(
            parallel="EP",
        )
    elif parallel == "hybrid":
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=1,
            fixed_ep_size=2,
            fixed_pp_size=2,
        )
    model1, booster1, optim1 = get_model(parallel)
    model2, booster2, optim2 = get_model(parallel)
    # param ckpt
    # check not equal
    try:
        check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
        raise AssertionError("state_dict should not be equal")
    except:
        pass
    # shard
    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
    booster2.load_model(model2, "./tmp_ckpt1")
    # check
    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)

    # optim ckpt
    criterion = lambda x: x.mean()
    data = torch.randint(0, 4, (2, 4)).cuda()
    label = torch.randint(0, 4, (2,)).cuda()
    if parallel == "hybrid":
        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
    else:
        kwargs = {}
    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
    optim1.step()
    optim1.zero_grad()
    # shard
    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
    dist.barrier()
    booster2.load_optimizer(optim2, "./tmp_ckpt2")
    # check
    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)

    if dist.get_rank() == 0:
        shutil.rmtree("./tmp_ckpt1")
        shutil.rmtree("./tmp_ckpt2")


def _run_dist(rank, world_size, port, parallel):
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    _test_moe_checkpoint(parallel)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
    spawn(_run_dist, world_size, parallel=parallel)


if __name__ == "__main__":
    test_moe_checkpoint(world_size=4, parallel="hybrid")
@@ -0,0 +1,31 @@
import copy

import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


class Config:
    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.hidden_act = hidden_act


def test_moe_layer():
    config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
    mistral_moe = MixtralSparseMoeBlock(config).cuda()
    colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe)).cuda()

    data = torch.randn(2, 8, 4).cuda()
    mistral_output = mistral_moe(data)[0]
    colossal_output = colossal_moe(data)[0]
    assert torch.allclose(
        mistral_output, colossal_output
    ), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"


if __name__ == "__main__":
    test_moe_layer()
@@ -0,0 +1,320 @@
import argparse

import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


@torch.no_grad()
def get_global_loss(loss, booster):
    global_loss = loss.clone().detach()
    dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
    global_loss.div_(booster.plugin.dp_size)
    return global_loss


class RandomDataset(Dataset):
    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }


def parse_args():
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help="The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")

    # lr scheduler
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")

    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")

    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )

    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communication overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    booster_kwargs = {}
    hybrid_dict = {
        "tp_size": 1,
        "custom_policy": MixtralForCausalLMPolicy(),
        "enable_fused_normalization": args.use_layernorm_kernel,
        "enable_jit_fused": args.use_kernel,
        "precision": args.precision,
        "zero_stage": args.zero_stage,
        "checkpoint_io": MixtralMoECheckpointIO,
    }
    mgr_dict = {}
    if args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            pp_size=args.pp_size,
            microbatch_size=args.microbatch_size,
            **hybrid_dict,
        )
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=args.dp_size,
            fixed_ep_size=args.ep_size,
            fixed_pp_size=args.pp_size,
            **mgr_dict,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    config = MixtralConfig.from_pretrained(args.model_name)
    config.use_cache = False
    config.num_local_experts = 1
    model = MixtralForCausalLM(config)
    model.num_experts = 8
    model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
    model = model.to(get_current_device())
    replace_moe_layer(model, enable_kernel=args.use_kernel)
    coordinator.print_on_master(f"Finish init model with config:\n{config}")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master(f"Finish init booster")

    # Load ckpt
    if args.load_checkpoint is None:
        load_model(args.model_name, model, booster, optimizer)
        coordinator.print_on_master(f"Finish load checkpoint")
    else:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master(f"Finish load optimizer")

    # Start finetuning
    coordinator.print_on_master(f"Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master() if not use_pipeline else not is_pp_last_stage,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        if coordinator._local_rank == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Apply load balance
                if (
                    args.load_balance
                    and args.load_balance_interval > 0
                    and (step + 1) % args.load_balance_interval == 0
                ):
                    coordinator.print_on_master(f"Apply load balance")
                    apply_load_balance(model, optimizer)
                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master(f"Finish training")


if __name__ == "__main__":
    main()
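A side note on the scheduler setup in main() above: when --warmup_steps is omitted, the script falls back to roughly 2.5% of the total number of steps. A tiny sketch of that arithmetic with made-up numbers:

# Illustrative numbers only, not taken from a real run.
num_epochs = 1
steps_per_epoch = 100                      # stands in for len(dataloader)
total_steps = num_epochs * steps_per_epoch
warmup_steps = int(total_steps * 0.025)    # default warmup: 2 steps out of 100
print(total_steps, warmup_steps)           # 100 2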
@@ -0,0 +1,19 @@
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001

# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
    train.py \
    --num_epoch 1 \
    --model_name $MODEL \
    --plugin "hybrid" \
    --batch_size $BATCH_SIZE \
    --lr $LR \
    --zero_stage 1 \
    --pp_size 2 \
    --dp_size 1 \
    --ep_size 8 \
@@ -0,0 +1 @@
1.0.0
@@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_communication: bool = True,
         use_ep_inside: bool = True,
         custom_policy: Policy = None,
+        checkpoint_io: Optional[MoECheckpintIO] = None,
     ) -> None:
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.checkpoint_io = checkpoint_io
         # we change pg mesh to (pp, dp, tp) for better moe performance
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)

@@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )

     def get_checkpoint_io(self) -> MoECheckpintIO:
-        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        if self.checkpoint_io is None:
+            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        else:
+            self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

     def configure(
@@ -1,6 +1,7 @@
 from .checkpoint import MoECheckpintIO
 from .experts import MLPExperts
-from .layers import SparseMLP
+from .layers import SparseMLP, apply_load_balance
+from .manager import MOE_MANAGER
 from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator

@@ -14,4 +15,6 @@ __all__ = [
     "UniformNoiseGenerator",
     "SparseMLP",
     "MoECheckpintIO",
+    "MOE_MANAGER",
+    "apply_load_balance",
 ]
@@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
+        torch.cuda.empty_cache()
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                 f"index located at {save_index_file}."
             )
         dist.barrier()
+        torch.cuda.empty_cache()

     # ========================================================
     # Abstract methods for optimizer loading/saving implementation
@@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

         def _get_param_id_from_optimizer_param(
-            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+            param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
         ):
             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param
             return optimizer.param_info["param2id"][id(working_param)]
@@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
         master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 id_map[param_id] = param

         # Read checkpoint index file.
@@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             new_pg = copy.deepcopy(saved_pg)
             new_pg["params"] = old_pg["params"]  # The parameters in the same group shouldn't change.
             updated_groups.append(new_pg)
-        # ep extra group
-        if MOE_MANAGER.parallel == "EP":
+        # ep param group
+        if len(optimizer.optim.param_groups) > len(saved_groups):
             new_pg = copy.deepcopy(saved_pg)
-            new_pg["params"] = optimizer.optim.param_groups[-1][
-                "params"
-            ]  # Only keep the parameters kept by current pipeline stage.
-            for param in new_pg["params"]:
-                param.data = param.data.to(torch.float32)
+            new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
             updated_groups.append(new_pg)
         optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):

                 file_path = os.path.join(ckpt_root_path, filename)
                 state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+
+                # Then shard the loaded optimizer states if using tp/zero.
+                for pid, state in list(state_dict.items()):
+                    if pid in id_map:
+                        param = id_map[pid]
+                        if master_to_working_map is not None and id(param) in master_to_working_map:
+                            working_param = master_to_working_map[id(param)]
+                        elif (
+                            hasattr(optimizer, "moe_master_to_working_map")
+                            and id(param) in optimizer.moe_master_to_working_map
+                        ):
+                            working_param = optimizer.moe_master_to_working_map[id(param)]
+                        else:
+                            working_param = param
+                        original_shape = optimizer.param_info["param2shape"][id(working_param)]
+                        sharded_state = self.pre_load_optim(
+                            state,
+                            working_param,
+                            current_shape=working_param.shape,
+                            original_shape=original_shape,
+                            device="cpu",
+                            inplace=True,
+                        )
+                        state_dict[pid] = sharded_state
+
                 load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
                 loaded_file.add(filename)

-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
-            if master_to_working_map is not None and id(param) in master_to_working_map:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.pre_load_optim(
-                state,
-                param,
-                current_shape=working_param.shape,
-                original_shape=original_shape,
-                device=device,
-                inplace=True,
-            )
-            optimizer.optim.state[param] = sharded_state
-
         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):

             if master_to_working_map is not None and id(param) in master_to_working_map:
                 working_param = master_to_working_map[id(param)]
+            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+                working_param = optimizer.moe_master_to_working_map[id(param)]
             else:
                 working_param = param

@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             prefix (str): Prefix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        torch.cuda.empty_cache()
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
                 f"You can find where each parameters has been saved in the "
                 f"index located at {final_index_file_path}."
             )
+        torch.cuda.empty_cache()

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
             self.ep_size = 1

         if gated:
-            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+            self.wi_gate = nn.Parameter(
+                torch.empty(
+                    num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
+                )
+            )
             self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         else:
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
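The resized wi_gate above only doubles its last dimension for the "swiglu" activation, which packs the gate and up projections into one matrix that the activation later splits in half. A self-contained sketch of that packing idea (shapes and names are illustrative, not the MLPExperts internals):

import torch
import torch.nn.functional as F


def swiglu_ffn(x, wi_gate, wo):
    # wi_gate: (hidden, 2 * intermediate); wo: (intermediate, hidden)
    gate, up = torch.chunk(x @ wi_gate, 2, dim=-1)  # split the packed projection
    return (F.silu(gate) * up) @ wo


x = torch.randn(4, 8)
out = swiglu_ffn(x, torch.randn(8, 32), torch.randn(16, 8))
print(out.shape)  # torch.Size([4, 8])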
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
+        router_loss: bool = True,
+        router_norm: bool = False,
         router_capacity_factor_train: float = 1.25,
         router_capacity_factor_eval: float = 2.0,
         router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
         enable_kernel: bool = False,
         enable_comm_overlap: bool = False,
         enable_hierarchical_comm: bool = False,
+        return_gate_logits: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_experts = num_experts
         self.gated = mlp_gated
+        self.return_gate_logits = return_gate_logits
         self.enable_kernel = enable_kernel
         self.enable_comm_overlap = enable_comm_overlap
         self.expert_parallel = MOE_MANAGER.get_parallel()
+        self.router_loss = router_loss
+        self.router_norm = router_norm

         # moe router
         noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
         tokens = inputs.reshape(-1, self.hidden_size)

         # the data type of the inputs in the gating should be fp32
-        fp32_input = tokens.to(torch.float)
-        fp32_weight = self.gate_weight.to(torch.float)
-        gate_output = F.linear(fp32_input, fp32_weight)
+        gate_logits = F.linear(tokens, self.gate_weight)
+        gate_output = gate_logits.to(torch.float)

         # update expert load
         if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):

         # the result from the router
         used_capacity, *route_result_list = self.router(
-            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+            inputs=gate_output,
+            use_kernel=self.enable_kernel,
+            ep_group=self.ep_group,
+            use_loss=self.router_loss,
+            use_norm=self.router_norm,
+        )

         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):

         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(
-                dispatch_data,
-                used_capacity,
-                overlap=self.enable_comm_overlap
-            )
+            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
-            raise NotImplementedError("This kind of communication has not been implemented yet.\n"
-                                      "Please use Experts build function.")
+            raise NotImplementedError(
+                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
+            )

         if self.enable_kernel:
             expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
         ans = torch.matmul(combine_weights, expert_output)

         ans = ans.reshape(inputs.shape)
-        return ans
+
+        if self.return_gate_logits:
+            return ans, gate_logits
+        else:
+            return ans

     def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_in = expert_in.unsqueeze(0)
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
         return expert_out

     def _ep_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_input = HierarchicalAllToAll.apply(
+                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+                expert_output = HierarchicalAllToAll.apply(
+                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
+                )
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
             chunk_size = dispatch_data.shape[1] // NUM_CHUNK
             input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
             dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
             for i in range(NUM_CHUNK + NUM_STAGES - 1):
                 if expert_out is not None:
                     expert_out.handle.wait()
-                    output[:, :, offset:offset + chunk_size, :] = expert_out.data
+                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                     offset += chunk_size
                     expert_out = None

                 # all2all last output
                 if _expert_out is not None:
-                    expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+                    expert_out = Capsule(
+                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
+                    )
                     _expert_out = None

                 # all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
         return output

     def _tp_process(
-        self,
-        dispatch_data: torch.Tensor,
-        used_capacity: torch.Tensor,
-        overlap: bool = False
+        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
     ) -> torch.Tensor:
         """
         without overlap:
@@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
             NUM_CHUNK = 4
             NUM_STAGES = 4

-            assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
-                "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+            assert (
+                dispatch_data.shape[0] % NUM_CHUNK == 0
+            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
             chunk_size = dispatch_data.shape[0] // NUM_CHUNK
             chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
             output = torch.empty_like(dispatch_data)
@@ -150,7 +150,14 @@ class Top1Router(MoeRouter):
                 high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
             ).rsample

-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_loss: bool = False,
+        use_norm: bool = False,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -207,7 +214,7 @@ class Top1Router(MoeRouter):
             weight = mask * probs.type_as(inputs)
             combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
             sec_mask = combine_weights.bool()
-            return used_capacity, combine_weights, sec_mask
+            return used_capacity, combine_weights, sec_mask, probs


 class Top2Router(MoeRouter):
@@ -240,7 +247,14 @@ class Top2Router(MoeRouter):
             drop_tks=drop_tks,
         )

-    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        use_kernel: bool = False,
+        ep_group: Optional[ProcessGroup] = None,
+        use_norm: bool = False,
+        use_loss: bool = True,
+    ) -> Tuple:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@@ -257,6 +271,10 @@ class Top2Router(MoeRouter):

         assert inputs.dtype == torch.float
         probs = F.softmax(inputs, dim=-1)
+        if use_norm:
+            routing_weights, _ = torch.topk(probs, 2, dim=-1)
+            probs = probs / routing_weights.sum(dim=-1, keepdim=True)
+
         num_experts = probs.size(-1)
         capacity = self.get_capacity(inputs.shape)

@@ -270,10 +288,11 @@ class Top2Router(MoeRouter):
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1

         # calculate loss
-        expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
-        self.set_aux_loss(probs, expert_indices, num_experts)
-        self.set_z_loss(inputs)
-        self.pop_router_loss()
+        if use_loss:
+            expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
+            self.set_aux_loss(probs, expert_indices, num_experts)
+            self.set_z_loss(inputs)
+            self.pop_router_loss()

         if not self.training and not self.drop_tks and ep_group is not None:
             max_num = torch.max(torch.sum(cmask, dim=0))
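The new use_norm branch in Top2Router.forward rescales the softmax probabilities so that the two selected experts sum to one, mirroring Mixtral's routing. A standalone numeric sketch of that renormalization (illustrative values):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # one token, four experts
probs = F.softmax(logits, dim=-1)                # ~[0.61, 0.22, 0.14, 0.03]
topk, _ = torch.topk(probs, 2, dim=-1)
probs = probs / topk.sum(dim=-1, keepdim=True)   # the top-2 entries now sum to 1
print(probs)                                     # selected experts carry ~0.73 and ~0.27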
@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
         return torch.nn.GELU()
     elif act == "swiglu":
         return SwiGLU
+    elif act == "silu":
+        return torch.nn.SiLU()
     else:
         raise NotImplementedError("Unsupported activation function")
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         # because they have different parallel strategy
         # so we need to store them separately in param_groups
         # instead of working_groups
-        moe_params = list()
+        self.working_moe_params = list()

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if self.moe_extra_dp_pg is None:
                     # skip moe param
                     if is_moe_tensor(param):
-                        moe_params.append(param)
+                        self.working_moe_params.append(param)
                         continue
                 group_params.append(param)

@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # managed by this data parallel rank
             param_group["params"] = master_param_current_rank

         # if there are moe params, store in additional group in optim
-        if len(moe_params) > 0:
+        if len(self.working_moe_params) > 0:
+            self._sync_master_param = False
             param_group = dict()
+            # create fp32 master param
             for key, value in self.optim.param_groups[0].items():
                 if key != "params":
                     param_group[key] = value
-            param_group["params"] = moe_params
+            self.master_moe_params = []
+            for param in self.working_moe_params:
+                self.master_moe_params.append(param.clone().to(torch.float32).detach())
+            # create mapping from master to working for optimizer io
+            self.moe_master_to_working_map = {}
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
+            # add to optim
+            param_group["params"] = self.master_moe_params
             self.optim.param_groups.append(param_group)

         # initialize communication stream for
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             # update the params in the optimizer
             self.optim.param_groups[group_id]["params"] = real_master_params[group_id]

+        # update param for moe ep
+        # move grad to master param and compute norm
+        if len(self.working_moe_params) > 0:
+            moe_grads = []
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                if master_moe_param.grad is not None:
+                    raise RuntimeError("Moe param should not have grad here")
+                grad = working_moe_param.grad
+                # no need to copy fp32 grad if master_weights is False
+                if self._master_weights:
+                    grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
+                master_moe_param.grad = grad
+                working_moe_param.grad = None
+                moe_grads.append(grad)
+                grad_partition_groups.append(grad)
+            norm_group = self._compute_grad_norm(gradients=moe_grads)
+            norm_groups.append(norm_group)
+            self.optim.param_groups[-1]["params"] = self.master_moe_params
+            del moe_grads
+
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)

-        # TODO: we should store master param for ep
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.data = param.data.to(torch.float32)
-                param.grad = param.grad.to(torch.float32)
-
         # update the parameters
         self.optim.step()

-        # release the moe grad
-        if len(self.param_groups) > len(self._working_param_groups):
-            for param in self.param_groups[-1]["params"]:
-                param.grad = None
-                param.data = param.data.to(self._dtype)
+        # release moe grad
+        if len(self.working_moe_params) > 0:
+            for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+                master_moe_param.grad = None
+                working_moe_param.data = (
+                    master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
+                )

         # release the grad
         grad_partition_groups = []
@@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

+    def sync_moe_master_param(self):
+        for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
+            master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
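The optimizer changes above keep an fp32 master copy of every expert-parallel parameter and an id()-based map from master to working tensor, which the checkpoint IO uses to resolve parameter ids. A stripped-down sketch of that bookkeeping outside of any ZeRO machinery (names are illustrative):

import torch

working = [torch.randn(4, 4, dtype=torch.bfloat16) for _ in range(2)]  # low-precision training params
master = [p.clone().float().detach() for p in working]                 # fp32 copies the optimizer updates
master_to_working = {id(m): w for m, w in zip(master, working)}

# after an optimizer step on the fp32 masters, copy the result back down
for m, w in zip(master, working):
    w.data.copy_(m.data.to(w.dtype))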
@@ -1,13 +1,22 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.testing import assert_close

+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
 from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.legacy.registry import GRADIENT_HANDLER
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_moe_epsize_param_dict
+from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
+
+
+def delete_moe_info(model):
+    for _, param in model.named_parameters():
+        if hasattr(param, "moe_info"):
+            delattr(param, "moe_info")


 class MoeModel(nn.Module):
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
     for i in range(world_size - 1):
         a = tensor_list[i]
         b = tensor_list[i + 1]
-        assert not torch.allclose(a, b), \
-            (f"expected tensors on rank {i} and {i + 1} not to be equal "
-             f"but they are, {a} vs {b}")
+        assert not torch.allclose(a, b), (
+            f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
+        )
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+    model.train()
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+
+    if isinstance(model, LowLevelZeroModel):
+        optimizer.backward(loss)
+    else:
+        loss.backward()
+    return y
+
+
+def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+    """Sync the parameters of tp model from ep model
+
+    Args:
+        local_model (MoeModule)
+        ep_model (MoeModule)
+    """
+    for (local_name, local_param), (ep_name, ep_param) in zip(
+        local_model.named_parameters(), ep_model.named_parameters()
+    ):
+        assert local_name in ep_name, print(f"{local_name} != {ep_name}")
+        if "experts" not in local_name:
+            if assert_grad_flag:
+                assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
+                assert torch.allclose(local_param.grad, ep_param.grad)
+            else:
+                local_param.data.copy_(ep_param.data)
+            continue
+
+        # gather param from ep model
+        param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+        dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+        all_param = torch.cat(param_list, dim=0)
+        if assert_grad_flag:
+            grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+            dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+            all_grad = torch.cat(grad_list, dim=0)
+
+        if assert_grad_flag:
+            assert torch.allclose(local_param, all_param)
+            assert torch.allclose(local_param.grad, all_grad)
+        else:
+            local_param.data.copy_(all_param.data)
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+    rtol = None
+    atol = None
+    if dtype is torch.float16:
+        rtol = 5e-2
+        atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
+
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype).to(a.device)
+
+    assert_close(a, b, rtol=rtol, atol=atol)
@ -4,102 +4,75 @@ import torch
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
|
||||||
from colossalai.moe.manager import MOE_MANAGER
|
from colossalai.moe.manager import MOE_MANAGER
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
|
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
|
||||||
|
|
||||||
|
|
||||||
def split_ddp_grad(grad, world_size):
|
def run_zero_test(local_rank, stage=1):
|
||||||
with torch.no_grad():
|
|
||||||
grad = grad.clone().detach().flatten()
|
|
||||||
padding_size = (world_size - grad.numel() % world_size) % world_size
|
|
||||||
if padding_size > 0:
|
|
||||||
grad = torch.nn.functional.pad(grad, [0, padding_size])
|
|
||||||
splited_grad = grad.split(grad.numel() // world_size)
|
|
||||||
return splited_grad
|
|
||||||
|
|
||||||
|
|
||||||
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
|
|
||||||
model.train()
|
|
||||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
|
||||||
if criterion:
|
|
||||||
y = model(data)
|
|
||||||
loss = criterion(y, label)
|
|
||||||
else:
|
|
||||||
loss = model(data, label)
|
|
||||||
loss = loss.float()
|
|
||||||
|
|
||||||
if isinstance(model, LowLevelZeroModel):
|
|
||||||
optimizer.backward(loss)
|
|
||||||
else:
|
|
||||||
loss.backward()
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
def run_zero_test(local_rank, world_size, stage=1):
|
|
||||||
     criterion = torch.nn.CrossEntropyLoss()

-    zero_model = MoeModel()
-    optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    booster = Booster(plugin=plugin)
-    zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP")
+    moe_model = MoeModel().bfloat16()
+    moe_optimizer = torch.optim.Adam(moe_model.parameters())
+    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    moe_booster = Booster(plugin=moe_plugin)
+    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)

-    torch_model = MoeModel()
-    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
-        torch_param.data.copy_(zero_param.data)
-    torch_model = torch_model.cuda()
-    grad_handler = MoeGradientHandler(torch_model)
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    zero_model = MoeModel().bfloat16()
+    delete_moe_info(zero_model)
+    zero_optimizer = torch.optim.Adam(zero_model.parameters())
+    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    zero_booster = Booster(plugin=zero_plugin)
+    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
+    sync_local_from_ep(zero_model, moe_model)

-    # assert zero model
-    for (torch_name, torch_param), (zero_name, zero_param) in zip(
-        torch_model.named_parameters(), zero_model.module.named_parameters()
-    ):
-        assert zero_name == torch_name
-        assert torch.allclose(zero_param.data, torch_param.data)
-
-    data = torch.randn(16, 4).cuda()
+    data = torch.randn(16, 4).bfloat16().cuda()
     label = torch.randint(0, 4, (16,)).cuda()

-    torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
-    zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
-    assert torch.allclose(torch_out, zero_out)
-    grad_handler.handle_gradient()
+    zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+    assert torch.allclose(zero_out, moe_out)

-    for (zero_name, zero_param), (torch_name, torch_param) in zip(
-        zero_model.module.named_parameters(), torch_model.named_parameters()
-    ):
-        assert zero_name == torch_name
-        zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
-        if hasattr(zero_param, "moe_info"):
-            assert len(zero_grad_list) == 0
-            assert torch.allclose(zero_param.grad, torch_param.grad)
-        else:
-            assert len(zero_grad_list) > 0
-            torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
-            if stage == 2:
-                torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
-            assert len(zero_grad_list) == len(torch_grad_list)
-            for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
-                assert torch.allclose(zero_grad, torch_grad)
+    for (moe_name, moe_param), (zero_name, zero_param) in zip(
+        moe_model.module.named_parameters(), zero_model.module.named_parameters()
+    ):
+        assert moe_name == zero_name
+        moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
+        zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
+        if hasattr(moe_param, "moe_info"):
+            assert len(moe_grad_list) == 0
+            if stage == 1:
+                zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
+            else:
+                zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
+            assert torch.allclose(
+                moe_param.grad, zero_grad, atol=1e-5
+            ), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
+        else:
+            assert len(moe_grad_list) > 0
+            assert len(moe_grad_list) == len(zero_grad_list)
+            for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
+                assert torch.allclose(moe_grad, zero_grad)


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(parallel="EP")
     seed_all(42 + rank)
-    run_zero_test(rank, world_size, stage=1)
-    run_zero_test(rank, world_size, stage=2)
+    run_zero_test(rank, stage=stage)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
 @rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
-    spawn(run_dist, world_size)
+def test_moe_zero_model(world_size, stage):
+    spawn(run_dist, world_size, stage=stage)


 if __name__ == "__main__":
-    test_moe_zero_model(world_size=2)
+    test_moe_zero_model(world_size=2, stage=1)
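The old test above compares ZeRO's partitioned gradients against a replicated DDP-style gradient by slicing the full gradient with a split_ddp_grad helper (its removed definition appears in the next hunk). The snippet below is only a minimal, single-process sketch of that slicing idea, not part of this commit; the toy tensor and the world_size value are illustrative assumptions.

import torch


def split_ddp_grad(grad: torch.Tensor, world_size: int):
    # Flatten, pad to a multiple of world_size, then split into equal per-rank shards,
    # mirroring the helper the old test relied on.
    grad = grad.clone().detach().flatten()
    padding_size = (world_size - grad.numel() % world_size) % world_size
    if padding_size > 0:
        grad = torch.nn.functional.pad(grad, [0, padding_size])
    return grad.split(grad.numel() // world_size)


if __name__ == "__main__":
    full_grad = torch.arange(10, dtype=torch.float32)  # toy "gradient" (assumption)
    shards = split_ddp_grad(full_grad, world_size=4)   # world_size=4 is illustrative
    # Concatenating the shards and trimming the padding recovers the original gradient,
    # which is what makes comparing rank-local shard lists against such a split meaningful.
    rebuilt = torch.cat(shards)[: full_grad.numel()]
    assert torch.equal(rebuilt, full_grad)
    print([s.numel() for s in shards])  # [3, 3, 3, 3]; the last shard ends in two padded zeros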
@@ -4,89 +4,80 @@ import torch
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep


-def split_ddp_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
-def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
-    model.train()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-
-    if isinstance(model, LowLevelZeroModel):
-        optimizer.backward(loss)
-    else:
-        loss.backward()
-    return y
-
-
-def run_zero_optim_test(local_rank, world_size, stage=1):
+def run_zero_test(local_rank, stage=1):
     criterion = torch.nn.CrossEntropyLoss()

-    zero_model = MoeModel()
-    zero_optimizer = torch.optim.Adam(zero_model.parameters())
-    plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
-    booster = Booster(plugin=plugin)
-    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel="EP")
+    moe_model = MoeModel().bfloat16()
+    moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
+    moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    moe_booster = Booster(plugin=moe_plugin)
+    moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)

-    torch_model = MoeModel()
-    for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
-        torch_param.data.copy_(zero_param.data)
-    torch_optimizer = torch.optim.Adam(torch_model.parameters())
-    torch_model = torch_model.cuda()
-    grad_handler = MoeGradientHandler(torch_model)
+    MOE_MANAGER.__init__()
+    MOE_MANAGER.setup(parallel=None)
+    zero_model = MoeModel().bfloat16()
+    delete_moe_info(zero_model)
+    sync_local_from_ep(zero_model, moe_model)
+    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
+    zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
+    zero_booster = Booster(plugin=zero_plugin)
+    zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)

-    for _ in range(2):
-        data = torch.randn(16, 4).cuda() / (local_rank + 1)
-        label = torch.randint(0, 4, (16,)).cuda()
-        run_fwd_bwd(torch_model, data, label, criterion, None)
-        run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
-        grad_handler.handle_gradient()
+    for (moe_name, moe_param), (zero_name, zero_param) in zip(
+        moe_model.named_parameters(), zero_model.named_parameters()
+    ):
+        if ".experts." in moe_name:
+            continue
+        assert moe_name == zero_name
+        assert torch.allclose(
+            moe_param.data, zero_param.data
+        ), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"

-        torch_optimizer.step()
+    for _ in range(1):
+        data = torch.randn(2, 4).bfloat16().cuda()
+        label = torch.randint(0, 4, (2,)).cuda()
+
+        moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
+        zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+        assert torch.allclose(zero_out, moe_out)
+        moe_optimizer.step()
         zero_optimizer.step()

-        for (torch_name, torch_param), (zero_name, zero_param) in zip(
-            torch_model.named_parameters(), zero_model.named_parameters()
-        ):
-            assert torch.allclose(
-                torch_param.data, zero_param.data
-            ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
+        for (moe_name, moe_param), (zero_name, zero_param) in zip(
+            moe_model.named_parameters(), zero_model.named_parameters()
+        ):
+            assert moe_name == zero_name
+            if is_moe_tensor(moe_param):
+                param_size = moe_param.shape[0]
+                zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
+            loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)

-        torch_optimizer.zero_grad()
+        moe_optimizer.zero_grad()
         zero_optimizer.zero_grad()


-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, stage):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    MOE_MANAGER.setup(parallel="EP")
-    run_zero_optim_test(rank, world_size, stage=1)
-    run_zero_optim_test(rank, world_size, stage=2)
+    seed_all(42 + rank)
+    run_zero_test(rank, stage=stage)


 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("stage", [1, 2])
 @rerun_if_address_is_in_use()
-def test_moe_zero_optim(world_size):
-    spawn(run_dist, world_size)
+def test_moe_zero_optim(world_size, stage):
+    spawn(run_dist, world_size, stage=stage)


 if __name__ == "__main__":
-    test_moe_zero_optim(world_size=2)
+    test_moe_zero_optim(world_size=2, stage=1)
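The loose_close helper imported from tests/test_moe/moe_utils above is not shown in this diff. The sketch below is only a guess at what such a dtype-aware comparison might look like; the tolerance values are made-up assumptions and this is not the actual helper.

import torch


def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype = torch.float32) -> None:
    # Use looser tolerances for low-precision dtypes; the numbers here are assumptions.
    if dtype in (torch.bfloat16, torch.float16):
        rtol, atol = 4e-3, 4e-3
    else:
        rtol, atol = 1e-5, 1e-8
    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)
    assert torch.allclose(a, b, rtol=rtol, atol=atol), f"max diff: {(a - b).abs().max()}"


if __name__ == "__main__":
    x = torch.randn(4, 4)
    y = x.bfloat16().float()  # simulate bf16 round-off
    # A strict torch.allclose(x, y) would usually fail here, while the loose check passes.
    loose_close(x, y, dtype=torch.bfloat16)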