[moe] init mixtral impl

pull/5372/head
Xuanlei Zhao 2023-12-14 17:52:05 +08:00 committed by ver217
parent c53ddda88f
commit 7d8e0338a4
28 changed files with 2025 additions and 223 deletions


@ -0,0 +1,205 @@
import logging
import os
from pathlib import Path
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
class MixtralMoECheckpointIO(MoECheckpintIO):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@torch.no_grad()
def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
"""
Preprocess state_dict before loading and slice the state_dict of MOE tensors.
"""
model_param_dict = dict(model.named_parameters())
for name, param in list(state_dict.items()):
if ".gate.weight" in name:
new_name = "module." + name.replace(".gate.weight", ".gate_weight")
state_dict[new_name] = state_dict.pop(name)
elif ".experts." in name:
# This is an MoE tensor. In our MoE module all experts are concatenated into
# a single tensor, while Mixtral stores each expert separately, so we insert
# the loaded expert weight into its slot of the concatenated tensor.
# get model param
str_idx = name.index(".experts.")
expert_idx = int(name.split(".")[-3])
if ".w1." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
elif ".w2." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wo")
elif ".w3." in name:
model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
model_param_name = "module." + model_param_name
# skip params that are not held by this pipeline stage
if model_param_name not in model_param_dict:
continue
model_param = model_param_dict[model_param_name]
assert is_moe_tensor(model_param)
# get expert range
ep_rank = get_ep_rank(model_param)
ep_size = get_ep_size(model_param)
expert_num = 8 // ep_size
expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
# insert new param
if expert_idx in expert_range:
new_param = model_param
new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
state_dict[model_param_name] = new_param
state_dict.pop(name)
else:
new_name = "module." + name
state_dict[new_name] = state_dict.pop(name)
dist.barrier()
return state_dict
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that a file will not be loaded repeatedly.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
state_dict = self.pre_load_model(model, state_dict)
missing_keys = []
load_state_dict_into_model(
model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True,
)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
name = name.replace("module.", "")
name = name.replace(".gate_weight", ".gate.weight")
if ".experts.wi_gate" in name:
for i in range(8):
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
_load(new_name)
elif ".experts.wi_up" in name:
for i in range(8):
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
_load(new_name)
elif ".experts.wo" in name:
for i in range(8):
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
_load(new_name)
else:
_load(name)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@torch.no_grad()
def pre_save_model(self, model: nn.Module) -> dict:
torch.cuda.empty_cache()
state_dict = model.state_dict()
for name, param in list(model.named_parameters()):
if ".gate_weight" in name:
new_name = name.replace(".gate_weight", ".gate.weight")
state_dict[new_name] = state_dict.pop(name).cpu()
elif ".experts." in name:
ep_group = get_ep_group(param)
ep_rank = get_ep_rank(param)
ep_size = get_ep_size(param)
dp_rank = get_dp_rank(param)
if dp_rank == 0:
param = param.data.cuda()
all_param = [torch.zeros_like(param) for _ in range(ep_size)]
# gather param from every ep rank
dist.all_gather(all_param, param, group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
assert all_param.shape[0] == 8
for i in range(8):
if ".wi_gate" in name:
new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
elif ".wi_up" in name:
new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
elif ".wo" in name:
new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
new_name = new_name.replace("module.", "")
new_param = all_param[i].transpose(-1, -2)
state_dict[new_name] = new_param.cpu()
state_dict.pop(name)
else:
state_dict[name] = param.cpu()
for name, param in list(state_dict.items()):
new_name = name.replace("module.", "")
state_dict[new_name] = state_dict.pop(name)
torch.cuda.empty_cache()
if self.pp_size > 1:
if self.dp_rank == 0:
# gather state_dict from every pp rank
# because the ckpt is large, we split its keys into chunks
# and gather them one by one
new_state_dict = {}
state_dict_keys = list(state_dict.keys())
gap_key_num = min(30, len(state_dict_keys))
gap_keys = (len(state_dict_keys) + gap_key_num - 1) // gap_key_num
for i in range(gap_key_num):
cur_keys = state_dict_keys[i * gap_keys : (i + 1) * gap_keys]
cur_state_dict = {}
for k in cur_keys:
cur_state_dict[k] = state_dict[k]
out = [None for _ in range(self.pp_size)]
dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
if self.pp_rank == 0:
for o in out:
for k, v in o.items():
new_state_dict[k] = v.cpu()
state_dict = new_state_dict
dist.barrier()
return state_dict
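For readers of `pre_load_model` above, here is a minimal sketch (not part of this PR) of the expert-placement arithmetic: which global expert ids an EP rank owns, and which local slot of the concatenated expert tensor a checkpoint expert is written into. The 8-expert count matches Mixtral-8x7B; the function name is illustrative only.

```python
# Minimal sketch of the expert-placement arithmetic used in pre_load_model above.
# Assumes 8 experts in total, as in Mixtral-8x7B; names are illustrative only.
NUM_EXPERTS = 8

def local_expert_slot(expert_idx: int, ep_rank: int, ep_size: int):
    """Return the local slot for a global expert id, or None if this rank doesn't hold it."""
    experts_per_rank = NUM_EXPERTS // ep_size
    owned = range(ep_rank * experts_per_rank, (ep_rank + 1) * experts_per_rank)
    if expert_idx not in owned:
        return None  # the checkpoint shard is skipped on this rank
    return expert_idx - ep_rank * experts_per_rank

# With ep_size=4, rank 1 owns global experts 2 and 3, stored at local slots 0 and 1.
assert local_expert_slot(3, ep_rank=1, ep_size=4) == 1
assert local_expert_slot(5, ep_rank=1, ep_size=4) is None
```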


@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralSparseMoeBlock
from colossalai.lazy import LazyInitContext
from colossalai.moe import SparseMLP
class MixtralSparseMLP:
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
"It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
)
@staticmethod
def from_native_module(module: MixtralSparseMoeBlock, enable_kernel: bool) -> nn.Module:
r"""
Convert a native pytorch layer norm module to FusedLayerNorm module provided by apex,
and optionally marking parameters for gradient aggregation.
Args:
module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
Returns:
nn.Module: Union[FastLayerNorm, FusedLayerNorm].
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
with torch.no_grad():
LazyInitContext.materialize(module)
# get the attributes of the module
moe_kwargs = dict(
num_experts=8,
hidden_size=module.hidden_dim,
intermediate_size=module.ffn_dim,
router_top_k=module.top_k,
router_norm=True,
router_loss=False,
# router_capacity_factor_train=
# router_capacity_factor_eval=
mlp_activation="silu",
mlp_gated=True,
# enable_load_balance=
# load_balance_tolerance=
# load_balance_beam_width=
# load_balance_group_swap_factor=
enable_kernel=enable_kernel,
# enable_comm_overlap=
# enable_hierarchical_comm=
return_gate_logits=True,
)
dtype = module.gate.weight.dtype
device = module.gate.weight.device
sparse_mlp = SparseMLP(**moe_kwargs).to(dtype).to(device)
return sparse_mlp
def replace_moe_layer(model: nn.Module, enable_kernel: bool = False) -> nn.Module:
"""
Reverse the replace layer operation
Args:
module (torch.nn.Module): The object of layer to shard
"""
if isinstance(model, MixtralDecoderLayer):
model.block_sparse_moe = MixtralSparseMLP.from_native_module(
model.block_sparse_moe, enable_kernel=enable_kernel
)
else:
for _, child in model.named_children():
replace_moe_layer(child, enable_kernel)
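To tie this back to `MixtralMoECheckpointIO` above, the following is a hedged summary, derived from the replacements in `pre_load_model`/`pre_save_model`, of how HuggingFace Mixtral parameter names map onto the fused SparseMLP parameters after replacement. The `model.layers.{l}.block_sparse_moe` prefix and the placeholders `{l}` (layer index) and `{i}` (expert index) are assumptions for illustration.

```python
# Hedged summary of the parameter-name mapping implied by the checkpoint code above.
# "{l}" is a decoder layer index and "{i}" an expert index; the prefix is illustrative.
HF_TO_SPARSE_MLP = {
    "model.layers.{l}.block_sparse_moe.gate.weight": "model.layers.{l}.block_sparse_moe.gate_weight",
    "model.layers.{l}.block_sparse_moe.experts.{i}.w1.weight": "model.layers.{l}.block_sparse_moe.experts.wi_gate",  # slot i, transposed
    "model.layers.{l}.block_sparse_moe.experts.{i}.w3.weight": "model.layers.{l}.block_sparse_moe.experts.wi_up",    # slot i, transposed
    "model.layers.{l}.block_sparse_moe.experts.{i}.w2.weight": "model.layers.{l}.block_sparse_moe.experts.wo",       # slot i, transposed
}
```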


@ -0,0 +1,543 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MoeCausalLMOutputWithPast,
_prepare_4d_causal_attention_mask,
load_balancing_loss_func,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.shard import ShardConfig
__all__ = ["MixtralPolicy", "MixtralForCausalLMPolicy"]
class MixtralPolicy(Policy):
def config_sanity_check(self):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
)
if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=FusedRMSNorm,
),
SubModuleReplacementDescription(
suffix="post_attention_layernorm",
target_module=FusedRMSNorm,
),
],
policy=policy,
target_key=MixtralDecoderLayer,
)
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="norm",
target_module=FusedRMSNorm,
),
policy=policy,
target_key=MixtralModel,
)
if self.shard_config.enable_flash_attention:
raise NotImplementedError("Flash attention has already been replaced in mixtral.")
return policy
def postprocess(self):
return self.model
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if self.pipeline_stage_manager:
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
assert self.pipeline_stage_manager is not None
if self.model.__class__.__name__ == "MixtralModel":
module = self.model
else:
module = self.model.model
stage_manager = self.pipeline_stage_manager
held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
return held_layers
class MixtralModelPolicy(MixtralPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
policy = super().module_policy()
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralModel,
new_forward=MixtralPipelineForwards.mixtral_model_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
held_layers = super().get_held_layers()
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in llama model"""
return []
class MixtralForCausalLMPolicy(MixtralPolicy):
def module_policy(self):
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
]
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
model_cls=MixtralForCausalLM,
new_forward=MixtralPipelineForwards.mixtral_for_causal_lm_forward,
policy=policy,
)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_last_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []
class MixtralPipelineForwards:
"""
This class serves as a micro library for forward function substitution of Mixtral models
under pipeline setting.
"""
@staticmethod
def mixtral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
past_router_logits: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
if use_cache:
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
use_cache = False
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
output_attentions,
output_router_logits,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
output_router_logits,
use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if output_router_logits and past_router_logits is not None:
all_router_logits = past_router_logits + all_router_logits
if stage_manager.is_last_stage():
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
if v is not None
)
# always return a dict for intermediate stages
return {
"hidden_states": hidden_states,
"past_router_logits": all_router_logits,
}
@staticmethod
def mixtral_for_causal_lm_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = True,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
past_router_logits: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MixtralForCausalLM
>>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False
if output_hidden_states:
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = MixtralPipelineForwards.mixtral_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
past_router_logits=past_router_logits,
)
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(outputs[-1], self.num_experts, self.num_experts_per_tok)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=None,
hidden_states=outputs[0],
attentions=None,
router_logits=outputs[-1],
)
else:
out = {}
hidden_states = outputs.get("hidden_states")
out["hidden_states"] = hidden_states
if output_router_logits:
out["past_router_logits"] = outputs["past_router_logits"]
return out
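As a side note on the policy above: `get_held_layers` relies on `distribute_layers`/`get_stage_index` from the base `Policy`. Below is a minimal stand-alone approximation (not ColossalAI's implementation) of how decoder layers could be split across pipeline stages, assuming a near-even split with earlier stages taking the remainder.

```python
# Stand-alone approximation of distributing decoder layers over pipeline stages.
# Not ColossalAI's implementation; shown only to illustrate the stage-index logic.
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    base, rem = divmod(num_layers, num_stages)
    # earlier stages take one extra layer until the remainder is used up
    return [base + (1 if s < rem else 0) for s in range(num_stages)]

def stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# e.g. Mixtral-8x7B has 32 decoder layers; with 2 stages each stage holds 16:
print([stage_index(distribute_layers(32, 2), s) for s in range(2)])  # [(0, 16), (16, 32)]
```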


@ -0,0 +1,102 @@
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def move_to_cuda(batch, device):
return {k: v.to(device) for k, v in batch.items()}
@torch.no_grad()
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
# pytorch ckpt
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
# saved ckpt
elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
# download
else:
ckpt_path = snapshot_download(ckpt_path)
booster.load_model(model, ckpt_path)
if optimizer is not None:
optimizer.sync_moe_master_param()
optimizer.update_master_params(model)
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
"""
Load file in JSON format
"""
with open(file=file_path, mode="r", encoding="utf-8") as fp:
return json.load(fp)
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
"""
Save as JSON format
"""
with open(file=file_path, mode="w", encoding="utf-8") as fp:
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
save_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
epoch: int,
step: int,
batch_size: int,
coordinator: DistCoordinator,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
running_states = {
"epoch": epoch,
"step": step,
"sample_start_index": step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
load_dir: Union[str, os.PathLike],
booster: Booster,
model: torch.nn.Module,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
"""
Load model checkpoint, optimizer, LR scheduler and intermediate running states.
"""
# Update booster params states.
load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
return (
running_states["epoch"],
running_states["step"],
running_states["sample_start_index"],
)
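A hedged usage sketch of how `save_checkpoint` and `load_checkpoint` are meant to pair up. The directory names and step counts are illustrative, and `booster`, `model`, `optimizer`, `lr_scheduler`, and `coordinator` are assumed to come from the usual Booster setup shown in train.py.

```python
# Illustrative only: resuming and saving with the helpers defined above.
from colossal_moe.utils import load_checkpoint, save_checkpoint

def resume_then_save(booster, model, optimizer, lr_scheduler, coordinator,
                     load_dir=None, save_dir="./outputs"):
    start_epoch = start_step = 0
    if load_dir is not None:
        # e.g. load_dir="./outputs/epoch-0_step-1000", a folder written by save_checkpoint
        start_epoch, start_step, _ = load_checkpoint(load_dir, booster, model, optimizer, lr_scheduler)
    # ... training steps would go here ...
    # save_checkpoint appends an "epoch-{epoch}_step-{step}" suffix under save_dir
    save_checkpoint(save_dir, booster, model, optimizer, lr_scheduler,
                    epoch=start_epoch, step=start_step, batch_size=1, coordinator=coordinator)
```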


@ -0,0 +1,138 @@
import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_model
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["ep"],
help="Parallel methos.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"checkpoint_io": MixtralMoECheckpointIO,
"zero_stage": 1,
}
mgr_dict = {}
if args.plugin == "ep":
dp_size = dist.get_world_size()
plugin = MoeHybridParallelPlugin(
pp_size=1,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
max_ep_size=dp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build mixtral model
config = MixtralConfig.from_pretrained(args.model_name)
config.num_local_experts = 1  # don't change this; it does not affect the final model since the MoE layers are replaced below
with skip_init():
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Replace moe
with skip_init():
replace_moe_layer(model)
model.eval()
coordinator.print_on_master(f"Finish replace moe module")
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, _, _, _, _ = booster.boost(model=model)
coordinator.print_on_master(f"Finish init booster")
# load ckpt
load_model(args.model_name, model, booster)
coordinator.print_on_master(f"Finish load ckpt")
text = ["Hello my name is", "1+1=?"]
tokenizer.pad_token = tokenizer.unk_token
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())
outputs = model.module.generate(**inputs, max_new_tokens=20)
outputs = tokenizer.batch_decode(outputs)[0]
print(outputs)
if __name__ == "__main__":
main()


@ -0,0 +1,7 @@
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"
# ep
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
--model_name $MODEL \
--plugin "ep" \


@ -0,0 +1,5 @@
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets


@ -0,0 +1,43 @@
from setuptools import find_packages, setup
def fetch_requirements(path):
with open(path, "r") as fd:
return [r.strip() for r in fd.readlines()]
def fetch_readme():
with open("README.md", encoding="utf-8") as f:
return f.read()
def fetch_version():
with open("version.txt", "r") as f:
return f.read().strip()
setup(
name="colossal_moe",
version=fetch_version(),
packages=find_packages(
exclude=(
"tests",
"benchmarks",
"*.egg-info",
)
),
description="Colossal-AI MoE",
long_description=fetch_readme(),
long_description_content_type="text/markdown",
license="Apache Software License 2.0",
url="https://github.com/hpcaitech",
install_requires=fetch_requirements("requirements.txt"),
python_requires=">=3.6",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: System :: Distributed Computing",
],
)


@ -0,0 +1,185 @@
import os
import shutil
import pytest
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
attention_mask = torch.ones_like(input_ids)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
}
def run_fwd_bwd(
model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
model.train()
if pipeline:
train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
y = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = y["loss"]
else:
if criterion:
y = model(data).logits
loss = criterion(y)
else:
loss = model(data, label)
loss = loss.float()
if optimizer is not None:
optimizer.backward(loss)
else:
loss.backward()
return y
def get_config():
config = MixtralConfig(
vocab_size=300,
hidden_size=32,
intermediate_size=16,
num_hidden_layers=2,
dropout_rate=0.0,
)
return config
def get_model(parallel):
config = get_config()
model = MixtralForCausalLM(config).to(torch.bfloat16)
replace_moe_layer(model)
optim = torch.optim.Adam(model.parameters())
args = dict(
precision="bf16",
tp_size=1,
zero_stage=1,
custom_policy=MixtralForCausalLMPolicy(),
checkpoint_io=MixtralMoECheckpointIO,
)
if parallel == "ep":
plugin = MoeHybridParallelPlugin(
pp_size=1,
**args,
)
elif parallel == "hybrid":
plugin = MoeHybridParallelPlugin(
pp_size=2,
microbatch_size=1,
**args,
)
booster = Booster(plugin=plugin)
model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
return model, booster, optim
def _test_moe_checkpoint(parallel):
if dist.get_rank() == 0:
if os.path.exists("./tmp_ckpt1"):
shutil.rmtree("./tmp_ckpt1")
if os.path.exists("./tmp_ckpt2"):
shutil.rmtree("./tmp_ckpt2")
dist.barrier()
if parallel is None:
MOE_MANAGER.setup(
parallel=None,
)
elif parallel == "ep":
MOE_MANAGER.setup(
parallel="EP",
)
elif parallel == "hybrid":
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=1,
fixed_ep_size=2,
fixed_pp_size=2,
)
model1, booster1, optim1 = get_model(parallel)
model2, booster2, optim2 = get_model(parallel)
# param ckpt
# check not equal
try:
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
state_dicts_equal = True
except AssertionError:
state_dicts_equal = False
assert not state_dicts_equal, "state_dict should not be equal before loading"
# shard
booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
booster2.load_model(model2, "./tmp_ckpt1")
# check
check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
# optim ckpt
criterion = lambda x: x.mean()
data = torch.randint(0, 4, (2, 4)).cuda()
label = torch.randint(0, 4, (2,)).cuda()
if parallel == "hybrid":
kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
else:
kwargs = {}
run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
optim1.step()
optim1.zero_grad()
# shard
booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
dist.barrier()
booster2.load_optimizer(optim2, "./tmp_ckpt2")
# check
check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
if dist.get_rank() == 0:
shutil.rmtree("./tmp_ckpt1")
shutil.rmtree("./tmp_ckpt2")
def _run_dist(rank, world_size, port, parallel):
colossalai.launch(
config=dict(),
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
_test_moe_checkpoint(parallel)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", ["ep", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__":
test_moe_checkpoint(world_size=4, parallel="hybrid")


@ -0,0 +1,31 @@
import copy
import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
class Config:
def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.hidden_act = hidden_act
def test_moe_layer():
config = Config(hidden_size=4, intermediate_size=8, num_local_experts=32, num_experts_per_tok=2, hidden_act="silu")
mistral_moe = MixtralSparseMoeBlock(config).cuda()
colossal_moe = MixtralSparseMLP.from_native_module(copy.deepcopy(mistral_moe), enable_kernel=False).cuda()
data = torch.randn(2, 8, 4).cuda()
mistral_output = mistral_moe(data)[0]
colossal_output = colossal_moe(data)[0]
assert torch.allclose(
mistral_output, colossal_output
), f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"
if __name__ == "__main__":
test_moe_layer()


@ -0,0 +1,320 @@
import argparse
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER, apply_load_balance
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@torch.no_grad()
def get_global_loss(loss, booster):
global_loss = loss.clone().detach()
dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
global_loss.div_(booster.plugin.dp_size)
return global_loss
class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
def parse_args():
# basic settings
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
type=str,
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
parser.add_argument(
"--plugin",
type=str,
default="hybrid",
choices=["hybrid"],
help="Parallel methods.",
)
parser.add_argument(
"--output_path",
type=str,
default="./outputs",
help="The path of your saved model after finetuning.",
)
parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per dp group) for the training dataloader.",
)
parser.add_argument(
"--save_interval",
type=int,
default=1000,
help=" The interval (steps) of saving checkpoints.",
)
parser.add_argument(
"--precision",
type=str,
default="bf16",
choices=["fp32", "bf16", "fp16"],
help="The mixed precision training.",
)
parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
# optim
parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
# lr scheduler
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
# zero stage for all plugins
parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
# hybrid plugin
parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
# kernel
parser.add_argument(
"--use_kernel",
action="store_true",
help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
)
parser.add_argument(
"--use_layernorm_kernel",
action="store_true",
help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
)
# load balance
parser.add_argument(
"--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
)
parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
# communicate overlap
parser.add_argument(
"--comm_overlap",
action="store_true",
help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
)
# hierarchical all-to-all
parser.add_argument(
"--hierarchical_alltoall",
action="store_true",
help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
)
args = parser.parse_args()
return args
def main():
args = parse_args()
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
# Set plugin
booster_kwargs = {}
hybrid_dict = {
"tp_size": 1,
"custom_policy": MixtralForCausalLMPolicy(),
"enable_fused_normalization": args.use_layernorm_kernel,
"enable_jit_fused": args.use_kernel,
"precision": args.precision,
"zero_stage": args.zero_stage,
"checkpoint_io": MixtralMoECheckpointIO,
}
mgr_dict = {}
if args.plugin == "hybrid":
plugin = MoeHybridParallelPlugin(
pp_size=args.pp_size,
microbatch_size=args.microbatch_size,
**hybrid_dict,
)
MOE_MANAGER.setup(
parallel="EP",
mode="fixed",
fixed_dp_size=args.dp_size,
fixed_ep_size=args.ep_size,
fixed_pp_size=args.pp_size,
**mgr_dict,
)
else:
raise ValueError(f"Invalid plugin {args.plugin}")
coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
# Build Mixtral model
config = MixtralConfig.from_pretrained(args.model_name)
config.use_cache = False
config.num_local_experts = 1
model = MixtralForCausalLM(config)
model.num_experts = 8
model = model.to(torch.bfloat16) if args.precision == "bf16" else model.to(torch.float16)
model = model.to(get_current_device())
replace_moe_layer(model, enable_kernel=args.use_kernel)
coordinator.print_on_master(f"Finish init model with config:\n{config}")
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
collate_fn = None
dataloader = plugin.prepare_dataloader(
dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
)
# Set optimizer
optimizer = HybridAdam(
model_params=model.parameters(),
lr=args.lr,
betas=(0.9, 0.95),
weight_decay=args.weight_decay,
adamw_mode=True,
)
# Set lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(
optimizer=optimizer,
total_steps=args.num_epochs * len(dataloader),
warmup_steps=args.warmup_steps
if args.warmup_steps is not None
else int(args.num_epochs * len(dataloader) * 0.025),
eta_min=0.1 * args.lr,
)
# Set booster
booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
)
use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
coordinator.print_on_master(f"Finish init booster")
# Load ckpt
if args.load_checkpoint is None:
load_model(args.model_name, model, booster, optimizer)
coordinator.print_on_master(f"Finish load checkpoint")
else:
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
coordinator.print_on_master(f"Finish load optimizer")
# Start finetuning
coordinator.print_on_master(f"Start finetuning")
for epoch in range(args.num_epoch):
model.train()
train_dataloader_iter = iter(dataloader)
total_len = len(train_dataloader_iter)
with tqdm(
range(total_len),
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
disable=not is_pp_last_stage if use_pipeline else not coordinator.is_master(),
) as pbar:
for step in pbar:
if use_pipeline:
# Forward pass
outputs = booster.execute_pipeline(
train_dataloader_iter,
model,
lambda x, y: x.loss,
optimizer,
return_loss=True,
return_outputs=True,
)
# Backward and optimize
if is_pp_last_stage:
loss = outputs["loss"]
global_loss = get_global_loss(loss, booster)
if coordinator._local_rank == "0":
pbar.set_postfix({"Loss": global_loss.item()})
else:
# Forward pass
data = next(train_dataloader_iter)
data = move_to_cuda(data, torch.cuda.current_device())
outputs = model(**data)
loss = outputs["loss"]
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({"loss": loss.item()})
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Apply load balance
if (
args.load_balance
and args.load_balance_interval > 0
and (step + 1) % args.load_balance_interval == 0
):
coordinator.print_on_master(f"Apply load balance")
apply_load_balance(model, optimizer)
# save checkpoint
if (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
save_checkpoint(
args.output_path,
booster,
model,
optimizer,
lr_scheduler,
epoch,
step,
args.batch_size,
coordinator,
)
# save checkpoint at the end of each epoch
booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
# Finish training
coordinator.print_on_master(f"Finish training")
if __name__ == "__main__":
main()


@ -0,0 +1,19 @@
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001
# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
train.py \
--num_epoch 1 \
--model_name $MODEL \
--plugin "hybrid" \
--batch_size $BATCH_SIZE \
--lr $LR \
--zero_stage 1 \
--pp_size 2 \
--dp_size 1 \
--ep_size 8 \


@ -0,0 +1 @@
1.0.0


@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None,
+ checkpoint_io: Optional[MoECheckpintIO] = None,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
+ self.checkpoint_io = checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
def get_checkpoint_io(self) -> MoECheckpintIO:
- self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ if self.checkpoint_io is None:
+     self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ else:
+     self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def configure(


@ -1,6 +1,7 @@
from .checkpoint import MoECheckpintIO
from .experts import MLPExperts
- from .layers import SparseMLP
+ from .layers import SparseMLP, apply_load_balance
+ from .manager import MOE_MANAGER
from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
from .utils import NormalNoiseGenerator, UniformNoiseGenerator
@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator", "UniformNoiseGenerator",
"SparseMLP", "SparseMLP",
"MoECheckpintIO", "MoECheckpintIO",
"MOE_MANAGER",
"apply_load_balance",
] ]


@ -224,6 +224,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
+ torch.cuda.empty_cache()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@ -265,6 +266,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"index located at {save_index_file}." f"index located at {save_index_file}."
) )
dist.barrier() dist.barrier()
torch.cuda.empty_cache()
# ======================================================== # ========================================================
# Abstract methods for optimizer loading/saving implementation # Abstract methods for optimizer loading/saving implementation
@ -332,10 +334,12 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
- param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, optimizer=None
):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
+ elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
+     working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]
@ -347,7 +351,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
- param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
id_map[param_id] = param
# Read checkpoint index file.
@ -371,14 +375,10 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"]  # The parameters in the same group shouln't change.
updated_groups.append(new_pg)
- # ep extra group
- if MOE_MANAGER.parallel == "EP":
+ # ep param group
+ if len(optimizer.optim.param_groups) > len(saved_groups):
new_pg = copy.deepcopy(saved_pg)
- new_pg["params"] = optimizer.optim.param_groups[-1][
-     "params"
- ]  # Only keep the parameters kept by current pipeline stage.
- for param in new_pg["params"]:
-     param.data = param.data.to(torch.float32)
+ new_pg["params"] = optimizer.optim.param_groups[-1]["params"]
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
@@ -389,7 +389,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map, optimizer)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@@ -400,27 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
if pid in id_map:
param = id_map[pid]
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif (
hasattr(optimizer, "moe_master_to_working_map")
and id(param) in optimizer.moe_master_to_working_map
):
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.pre_load_optim(
state,
working_param,
current_shape=working_param.shape,
original_shape=original_shape,
device="cpu",
inplace=True,
)
state_dict[pid] = sharded_state
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -576,6 +583,8 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
if master_to_working_map is not None and id(param) in master_to_working_map:
working_param = master_to_working_map[id(param)]
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
working_param = optimizer.moe_master_to_working_map[id(param)]
else:
working_param = param
@@ -618,6 +627,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that stores state tensors
"""
torch.cuda.empty_cache()
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -723,6 +733,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
f"You can find where each parameter has been saved in the "
f"index located at {final_index_file_path}."
)
torch.cuda.empty_cache()
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
View File
@@ -67,7 +67,11 @@ class MLPExperts(nn.Module):
self.ep_size = 1
if gated:
self.wi_gate = nn.Parameter(
torch.empty(
num_experts, hidden_size, intermediate_size * 2 if activation == "swiglu" else intermediate_size
)
)
self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
else:
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
View File
@@ -51,6 +51,8 @@ class SparseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
router_top_k: int = 1,
router_loss: bool = True,
router_norm: bool = False,
router_capacity_factor_train: float = 1.25,
router_capacity_factor_eval: float = 2.0,
router_min_capacity: int = 4,
@@ -65,15 +67,19 @@ class SparseMLP(nn.Module):
enable_kernel: bool = False,
enable_comm_overlap: bool = False,
enable_hierarchical_comm: bool = False,
return_gate_logits: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.gated = mlp_gated
self.return_gate_logits = return_gate_logits
self.enable_kernel = enable_kernel
self.enable_comm_overlap = enable_comm_overlap
self.expert_parallel = MOE_MANAGER.get_parallel()
self.router_loss = router_loss
self.router_norm = router_norm
# moe router
noisy_func = get_noise_generator(router_noisy_policy, num_experts)
@@ -150,9 +156,8 @@ class SparseMLP(nn.Module):
tokens = inputs.reshape(-1, self.hidden_size)
# the data type of the inputs in the gating should be fp32
gate_logits = F.linear(tokens, self.gate_weight)
gate_output = gate_logits.to(torch.float)
# update expert load
if self.enable_load_balance == True:
@@ -165,7 +170,12 @@ class SparseMLP(nn.Module):
# the result from the router
used_capacity, *route_result_list = self.router(
inputs=gate_output,
use_kernel=self.enable_kernel,
ep_group=self.ep_group,
use_loss=self.router_loss,
use_norm=self.router_norm,
)
# dispatch_data: (num_experts, capacity, hidden_size)
if self.enable_kernel:
@@ -177,22 +187,15 @@ class SparseMLP(nn.Module):
# expert_output: (num_groups, num_experts, capacity, hidden_size)
if self.expert_parallel == "EP":
expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel == "TP":
expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
elif self.expert_parallel is None:
expert_output = self._local_process(dispatch_data)
else:
raise NotImplementedError(
"This kind of communication has not been implemented yet.\n" "Please use Experts build function."
)
if self.enable_kernel:
expert_output = expert_output.reshape(-1, self.hidden_size)
@@ -204,7 +207,11 @@ class SparseMLP(nn.Module):
ans = torch.matmul(combine_weights, expert_output)
ans = ans.reshape(inputs.shape)
if self.return_gate_logits:
return ans, gate_logits
else:
return ans
def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
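As a usage note for the return_gate_logits switch above, the toy module below (hypothetical TinyGate, not part of the repo) shows the calling convention a consumer would rely on: the forward returns (output, gate_logits), so the raw router scores can feed an external auxiliary loss.

import torch
import torch.nn.functional as F

class TinyGate(torch.nn.Module):
    # Illustrative stand-in for a gated MoE layer; the experts themselves are omitted.
    def __init__(self, hidden_size=4, num_experts=2, return_gate_logits=True):
        super().__init__()
        self.gate_weight = torch.nn.Parameter(torch.randn(num_experts, hidden_size))
        self.return_gate_logits = return_gate_logits

    def forward(self, x):
        gate_logits = F.linear(x, self.gate_weight)   # raw router scores per token
        out = x                                       # pretend the experts were applied here
        return (out, gate_logits) if self.return_gate_logits else out

out, logits = TinyGate()(torch.randn(3, 4))
assert logits.shape == (3, 2)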
@@ -212,10 +219,7 @@ class SparseMLP(nn.Module):
return expert_out
def _ep_process(
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
"""
Expert Parallel
@@ -228,10 +232,14 @@ class SparseMLP(nn.Module):
"""
if not overlap or dist.get_world_size(self.ep_group) == 1:
if self.ep_hierarchical_group is not None:
expert_input = HierarchicalAllToAll.apply(
dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
)
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = HierarchicalAllToAll.apply(
expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
)
return expert_output
else:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
@@ -249,7 +257,7 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4
NUM_STAGES = 4
assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
@@ -262,13 +270,15 @@ class SparseMLP(nn.Module):
for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset : offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None
# all2all last output
if _expert_out is not None:
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None
# all2all next input
@@ -288,10 +298,7 @@ class SparseMLP(nn.Module):
return output
def _tp_process(
self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
) -> torch.Tensor:
""" """
without overlap: without overlap:
@ -326,8 +333,9 @@ class SparseMLP(nn.Module):
NUM_CHUNK = 4 NUM_CHUNK = 4
NUM_STAGES = 4 NUM_STAGES = 4
assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ assert (
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" dispatch_data.shape[0] % NUM_CHUNK == 0
), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0) chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data) output = torch.empty_like(dispatch_data)

View File

@@ -150,7 +150,14 @@ class Top1Router(MoeRouter):
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_loss: bool = False,
use_norm: bool = False,
) -> Tuple:
""" """
Args: Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@ -207,7 +214,7 @@ class Top1Router(MoeRouter):
weight = mask * probs.type_as(inputs) weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool() sec_mask = combine_weights.bool()
return used_capacity, combine_weights, sec_mask return used_capacity, combine_weights, sec_mask, probs
class Top2Router(MoeRouter): class Top2Router(MoeRouter):
@@ -240,7 +247,14 @@ class Top2Router(MoeRouter):
drop_tks=drop_tks,
)
def forward(
self,
inputs: torch.Tensor,
use_kernel: bool = False,
ep_group: Optional[ProcessGroup] = None,
use_norm: bool = False,
use_loss: bool = True,
) -> Tuple:
""" """
Args: Args:
inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
@ -257,6 +271,10 @@ class Top2Router(MoeRouter):
assert inputs.dtype == torch.float assert inputs.dtype == torch.float
probs = F.softmax(inputs, dim=-1) probs = F.softmax(inputs, dim=-1)
if use_norm:
routing_weights, _ = torch.topk(probs, 2, dim=-1)
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)
@@ -270,10 +288,11 @@ class Top2Router(MoeRouter):
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
if use_loss:
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()
if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
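A quick numeric illustration of the use_norm branch above (made-up logits): the probabilities of the two selected experts are rescaled so they sum to 1 per token, matching Mixtral-style top-2 routing.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])             # one token, four experts
probs = F.softmax(logits, dim=-1)                          # softmax over all experts
routing_weights, _ = torch.topk(probs, 2, dim=-1)          # top-2 probabilities
probs = probs / routing_weights.sum(dim=-1, keepdim=True)  # the use_norm renormalization
top2, _ = torch.topk(probs, 2, dim=-1)
assert torch.allclose(top2.sum(dim=-1), torch.ones(1))     # the chosen pair now sums to 1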

View File

@@ -83,6 +83,8 @@ def get_activation(act: str) -> Callable:
return torch.nn.GELU()
elif act == "swiglu":
return SwiGLU
elif act == "silu":
return torch.nn.SiLU()
else:
raise NotImplementedError("Unsupported activation function")
View File
@@ -141,7 +141,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# because they have different parallel strategy
# so we need to store them separately in param_groups
# instead of working_groups
self.working_moe_params = list()
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -153,7 +153,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if self.moe_extra_dp_pg is None:
# skip moe param
if is_moe_tensor(param):
self.working_moe_params.append(param)
continue
group_params.append(param)
@@ -168,13 +168,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
# if there are moe params, store in additional group in optim
if len(self.working_moe_params) > 0:
self._sync_master_param = False
param_group = dict()
# create fp32 master param
for key, value in self.optim.param_groups[0].items():
if key != "params":
param_group[key] = value
self.master_moe_params = []
for param in self.working_moe_params:
self.master_moe_params.append(param.clone().to(torch.float32).detach())
# create mapping from master to working for optimizer io
self.moe_master_to_working_map = {}
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
self.moe_master_to_working_map[id(master_moe_param)] = working_moe_param
# add to optim
param_group["params"] = self.master_moe_params
self.optim.param_groups.append(param_group)
# initialize communication stream for
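A minimal sketch (hypothetical tensors, not the plugin's code) of the bookkeeping introduced above: bf16 working MoE parameters stay in the model, fp32 master copies are handed to the optimizer, and an id()-keyed map lets checkpoint IO walk from a master copy back to its working tensor.

import torch

working_moe_params = [torch.randn(2, 2, dtype=torch.bfloat16, requires_grad=True)]
master_moe_params = [p.clone().to(torch.float32).detach() for p in working_moe_params]

# master -> working mapping, keyed by id() because the two tensors are distinct objects
moe_master_to_working_map = {
    id(master): working for master, working in zip(master_moe_params, working_moe_params)
}
assert moe_master_to_working_map[id(master_moe_params[0])] is working_moe_params[0]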
@@ -593,24 +603,40 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# update the params in the optimizer
self.optim.param_groups[group_id]["params"] = real_master_params[group_id]
# move grad to master param and compute norm
if len(self.working_moe_params) > 0:
moe_grads = []
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
if master_moe_param.grad is not None:
raise RuntimeError("Moe param should not have grad here")
grad = working_moe_param.grad
# no need to copy fp32 grad if master_weights is False
if self._master_weights:
grad = grad.to(master_moe_param.dtype).to(master_moe_param.device)
master_moe_param.grad = grad
working_moe_param.grad = None
moe_grads.append(grad)
grad_partition_groups.append(grad)
norm_group = self._compute_grad_norm(gradients=moe_grads)
norm_groups.append(norm_group)
self.optim.param_groups[-1]["params"] = self.master_moe_params
del moe_grads
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
# update the parameters
self.optim.step()
# release moe grad
if len(self.working_moe_params) > 0:
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.grad = None
working_moe_param.data = (
master_moe_param.data.to(working_moe_param.device).to(working_moe_param.dtype).detach()
)
# release the grad
grad_partition_groups = []
@@ -640,6 +666,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
def sync_moe_master_param(self):
for master_moe_param, working_moe_param in zip(self.master_moe_params, self.working_moe_params):
master_moe_param.data = working_moe_param.data.clone().to(torch.float32).detach()
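The MoE update path described above can be pictured with a toy SGD step (made-up tensors and learning rate; the real code calls self.optim.step()): the bf16 working gradient is cast onto the fp32 master, the master is updated, and the result is copied back in the working dtype.

import torch

lr = 0.1
working = torch.randn(4, dtype=torch.bfloat16)     # bf16 parameter used in forward/backward
master = working.clone().to(torch.float32)         # fp32 master copy kept for the optimizer

working_grad = torch.full_like(working, 0.5)       # pretend backward produced this gradient
master_grad = working_grad.to(master.dtype)        # move the gradient onto the master copy
working_grad = None                                # the working gradient can now be released

master = master - lr * master_grad                 # stand-in for the real optimizer step
working = master.to(torch.bfloat16)                # copy the updated value back to bf16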
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
View File
@@ -1,13 +1,22 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
from colossalai.legacy.registry import GRADIENT_HANDLER
from colossalai.moe import SparseMLP
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_moe_epsize_param_dict
from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
def delete_moe_info(model):
for _, param in model.named_parameters():
if hasattr(param, "moe_info"):
delattr(param, "moe_info")
class MoeModel(nn.Module):
@@ -85,6 +94,74 @@ def assert_not_equal_in_group(tensor, process_group=None):
for i in range(world_size - 1):
a = tensor_list[i]
b = tensor_list[i + 1]
assert not torch.allclose(a, b), (
f"expected tensors on rank {i} and {i + 1} not to be equal " f"but they are, {a} vs {b}"
)
def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
return y
def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
"""Sync the parameters of tp model from ep model
Args:
local_model (MoeModule)
ep_model (MoeModule)
"""
for (local_name, local_param), (ep_name, ep_param) in zip(
local_model.named_parameters(), ep_model.named_parameters()
):
assert local_name in ep_name, print(f"{local_name} != {ep_name}")
if "experts" not in local_name:
if assert_grad_flag:
assert torch.allclose(local_param, ep_param), f"local_param: {local_param}, ep_param: {ep_param}"
assert torch.allclose(local_param.grad, ep_param.grad)
else:
local_param.data.copy_(ep_param.data)
continue
# gather param from ep model
param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
all_param = torch.cat(param_list, dim=0)
if assert_grad_flag:
grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
all_grad = torch.cat(grad_list, dim=0)
if assert_grad_flag:
assert torch.allclose(local_param, all_param)
assert torch.allclose(local_param.grad, all_grad)
else:
local_param.data.copy_(all_param.data)
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
a = a.detach().to(dtype)
b = b.detach().to(dtype).to(a.device)
assert_close(a, b, rtol=rtol, atol=atol)
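Example use of loose_close (made-up tensors): both sides are compared at the requested precision, so a bf16 round-trip of an fp32 reference passes under the looser bf16 tolerances.

import torch

ref = torch.randn(8)                               # fp32 reference values
approx = ref.to(torch.bfloat16)                    # result of a lower-precision code path
loose_close(approx, ref, dtype=torch.bfloat16)     # compared at bf16 with rtol=atol=4e-3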

View File

@@ -4,102 +4,75 @@
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, run_fwd_bwd, sync_local_from_ep
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters())
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters())
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
sync_local_from_ep(zero_model, moe_model)
data = torch.randn(16, 4).bfloat16().cuda()
label = torch.randint(0, 4, (16,)).cuda()
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
assert torch.allclose(zero_out, moe_out)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.module.named_parameters(), zero_model.module.named_parameters()
):
assert moe_name == zero_name
moe_grad_list = moe_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(moe_param))
zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
if hasattr(moe_param, "moe_info"):
assert len(moe_grad_list) == 0
if stage == 1:
zero_grad = zero_grad_list[local_rank].view(moe_param.grad.shape)
else:
zero_grad = zero_grad_list[0].view(moe_param.grad.shape)
assert torch.allclose(
moe_param.grad, zero_grad, atol=1e-5
), f"zero grad:\n{moe_param.grad}\ntorch grad:\n{zero_grad}\nmax diff: {(moe_param.grad - zero_grad).abs().max()}, mean diff: {(moe_param.grad - zero_grad).abs().mean()}"
else:
assert len(moe_grad_list) > 0
assert len(moe_grad_list) == len(zero_grad_list)
for moe_grad, zero_grad in zip(moe_grad_list, zero_grad_list):
assert torch.allclose(moe_grad, zero_grad)
def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_model(world_size, stage):
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_model(world_size=2, stage=1)
View File
@@ -4,89 +4,80 @@
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, run_fwd_bwd, sync_local_from_ep
def run_zero_test(local_rank, stage=1):
criterion = torch.nn.CrossEntropyLoss()
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel="EP")
moe_model = MoeModel().bfloat16()
moe_optimizer = torch.optim.Adam(moe_model.parameters(), lr=1.0)
moe_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
moe_booster = Booster(plugin=moe_plugin)
moe_model, moe_optimizer, _, _, _ = moe_booster.boost(moe_model, moe_optimizer)
MOE_MANAGER.__init__()
MOE_MANAGER.setup(parallel=None)
zero_model = MoeModel().bfloat16()
delete_moe_info(zero_model)
sync_local_from_ep(zero_model, moe_model)
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1.0)
zero_plugin = LowLevelZeroPlugin(stage=stage, precision="bf16")
zero_booster = Booster(plugin=zero_plugin)
zero_model, zero_optimizer, _, _, _ = zero_booster.boost(zero_model, zero_optimizer)
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
if ".experts." in moe_name:
continue
assert moe_name == zero_name
assert torch.allclose(
moe_param.data, zero_param.data
), f"{moe_name}\ntorch_param {moe_param.data}\nzero_param {zero_param.data}"
for _ in range(1):
data = torch.randn(2, 4).bfloat16().cuda()
label = torch.randint(0, 4, (2,)).cuda()
moe_out = run_fwd_bwd(moe_model, data, label, criterion, moe_optimizer)
zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
assert torch.allclose(zero_out, moe_out)
moe_optimizer.step()
zero_optimizer.step()
for (moe_name, moe_param), (zero_name, zero_param) in zip(
moe_model.named_parameters(), zero_model.named_parameters()
):
assert moe_name == zero_name
if is_moe_tensor(moe_param):
param_size = moe_param.shape[0]
zero_param = zero_param[local_rank * param_size : (local_rank + 1) * param_size]
loose_close(moe_param.data, zero_param.data, dtype=moe_param.dtype)
moe_optimizer.zero_grad()
zero_optimizer.zero_grad()
def run_dist(rank, world_size, port, stage):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
seed_all(42 + rank)
run_zero_test(rank, stage=stage)
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("stage", [1, 2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size, stage):
spawn(run_dist, world_size, stage=stage)
if __name__ == "__main__":
test_moe_zero_optim(world_size=2, stage=1)