mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement * [misc] fix typo, remove redundant file (#5867) * [misc] fix typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Feature] deepseek support & unit test * [misc] remove debug code & useless print * [misc] fix typos (#5872) * [Feature] remove modeling file, use auto config. (#5884) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [Deepseek] remove redundant code (#5888) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [Feature/deepseek] resolve comment. (#5889) * [misc] fix typos * [Feature] deepseek support via auto model, remove modeling file * [misc] delete useless file * [misc] fix typos * [misc] remove redundant code * [misc] mv module replacement into if branch * [misc] add some warning message and modify some code in unit test * [misc] fix typos --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5894/head
parent
7997683aac
commit
3420921101
|
@ -147,7 +147,7 @@ class ProcessGroupMesh:
|
|||
ProcessGroup: The process group with the given ranks.
|
||||
"""
|
||||
ranks_in_group = sorted(ranks_in_group)
|
||||
if tuple(ranks_in_group) not in self._group_to_ranks:
|
||||
if tuple(ranks_in_group) not in self._ranks_to_group:
|
||||
group = dist.new_group(ranks_in_group, backend=backend)
|
||||
self._ranks_to_group[tuple(ranks_in_group)] = group
|
||||
self._group_to_ranks[group] = tuple(ranks_in_group)
|
||||
|
|
|
@ -0,0 +1,429 @@
|
|||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import is_flash_attn_2_available, logging
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
from colossalai.shardformer.shard.utils import set_tensors_to_none
|
||||
|
||||
|
||||
# copied from modeling_deepseek.py
|
||||
class AddAuxiliaryLoss(torch.autograd.Function):
|
||||
"""
|
||||
The trick function of adding auxiliary (aux) loss,
|
||||
which includes the gradient of the aux loss during backpropagation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, loss):
|
||||
assert loss.numel() == 1
|
||||
ctx.dtype = loss.dtype
|
||||
ctx.required_aux_loss = loss.requires_grad
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_loss = None
|
||||
if ctx.required_aux_loss:
|
||||
grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
|
||||
return grad_output, grad_loss
|
||||
|
||||
|
||||
class EPDeepseekMoE(nn.Module):
|
||||
def __init__(self):
|
||||
super(EPDeepseekMoE, self).__init__()
|
||||
|
||||
def setup_ep(self, ep_group: ProcessGroup):
|
||||
ep_group = ep_group
|
||||
self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
|
||||
self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
|
||||
self.num_experts = self.config.n_routed_experts
|
||||
assert self.num_experts % self.ep_size == 0
|
||||
self.ep_group = ep_group
|
||||
self.num_experts_per_ep = self.num_experts // self.ep_size
|
||||
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
|
||||
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
|
||||
set_tensors_to_none(self.experts, exclude=set(held_experts))
|
||||
for p in self.experts.parameters():
|
||||
p.ep_group = ep_group
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
|
||||
LazyInitContext.materialize(module)
|
||||
if module.__class__.__name__ == "DeepseekMLP":
|
||||
return module
|
||||
module.__class__ = EPDeepseekMoE
|
||||
assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
|
||||
module.setup_ep(kwargs["ep_group"])
|
||||
return module
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
|
||||
topk_experts_idx, topk_experts_weight, aux_loss = self.gate(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # [t0, t1, t2 ...]
|
||||
hidden_states = hidden_states.repeat_interleave(
|
||||
self.num_experts_per_tok, dim=0
|
||||
) # after repeat_interleave: [t0 t0 t1 t1 t2 t2 ... ]
|
||||
|
||||
flat_topk_experts_idx = topk_experts_idx.view(-1) # [e0 e1 e2 ...]
|
||||
# The elements of flat_topk_token_idx are token ids, which are arranged in ascending order of expert ids.
|
||||
flat_topk_token_idx = flat_topk_experts_idx.argsort()
|
||||
|
||||
# Now we adjust the order of the hidden states, also in ascending order of expert id
|
||||
dispatch_states = hidden_states[flat_topk_token_idx]
|
||||
input_split_sizes = flat_topk_experts_idx.bincount(minlength=self.num_experts) # [n0, n1, n2, n3]
|
||||
output_split_sizes = torch.zeros_like(input_split_sizes)
|
||||
|
||||
# [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
|
||||
dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
|
||||
|
||||
input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
|
||||
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
|
||||
output_states = MoeInGradScaler.apply(output_states, self.ep_size)
|
||||
|
||||
if output_states.size(0) > 0:
|
||||
if self.num_experts_per_ep == 1:
|
||||
expert = self.experts[self.expert_start_idx]
|
||||
output_states = expert(output_states)
|
||||
else:
|
||||
output_states_splits = output_states.split(output_split_sizes.tolist())
|
||||
output_states_list = []
|
||||
for i, split_states in enumerate(output_states_splits):
|
||||
if split_states.size(0) == 0: # no token routed to this experts
|
||||
continue
|
||||
expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
|
||||
split_states = expert(split_states)
|
||||
output_states_list.append(split_states)
|
||||
output_states = torch.cat(output_states_list)
|
||||
output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
|
||||
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
|
||||
recover_token_idx = torch.empty_like(flat_topk_token_idx)
|
||||
recover_token_idx[flat_topk_token_idx] = torch.arange(
|
||||
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
|
||||
)
|
||||
|
||||
output_hidden_states = dispatch_states[recover_token_idx] # t0 t0 t1 t1 t2 t2
|
||||
output_hidden_states = output_hidden_states.view(-1, self.num_experts_per_tok, orig_shape[-1])
|
||||
output_hidden_states = (output_hidden_states * topk_experts_weight[:, :, None]).sum(dim=-2) # (B*S, h)
|
||||
output_hidden_states = output_hidden_states.view(*orig_shape)
|
||||
output_hidden_states = AddAuxiliaryLoss.apply(output_hidden_states, aux_loss)
|
||||
if self.config.n_shared_experts is not None:
|
||||
output_hidden_states = output_hidden_states + self.shared_experts(identity)
|
||||
return output_hidden_states
|
||||
|
||||
|
||||
class DeepseekPipelineForwards:
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
under pipeline setting.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def deepseek_model_forward(
|
||||
self: "DeepseekModel",
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if stage_manager.is_first_stage():
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
|
||||
use_cache = False
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if is_flash_attn_2_available():
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
hidden_states,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
start_idx, end_idx = stage_index[0], stage_index[1]
|
||||
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = (layer_outputs[2 if output_attentions else 1],)
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
# always return dict for imediate stage
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def deepseek_for_causal_lm_forward(
|
||||
self: "DeepseekForCausalLM",
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
shard_config: ShardConfig = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MixtralForCausalLM
|
||||
|
||||
>>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = DeepseekPipelineForwards.deepseek_model_forward(
|
||||
self.model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
past_key_values = None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=None,
|
||||
hidden_states=outputs[0],
|
||||
attentions=None,
|
||||
)
|
||||
else:
|
||||
out = {}
|
||||
hidden_states = outputs.get("hidden_states")
|
||||
out["hidden_states"] = hidden_states
|
||||
return out
|
|
@ -160,6 +160,13 @@ _POLICY_LIST = {
|
|||
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
|
||||
),
|
||||
# Deepseek
|
||||
"transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
|
||||
file_name="deepseek", class_name="DeepseekModelPolicy"
|
||||
),
|
||||
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
|
||||
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
|
||||
),
|
||||
# Falcon
|
||||
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
|
||||
file_name="falcon", class_name="FalconModelPolicy"
|
||||
|
@ -252,7 +259,6 @@ def get_autopolicy(model: nn.Module) -> Policy:
|
|||
"""
|
||||
full_name = _fullname(model)
|
||||
policy_location = _POLICY_LIST.get(full_name, None)
|
||||
|
||||
if policy_location is None:
|
||||
raise NotImplementedError(
|
||||
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["DeepseekPolicy", "DeepseekForCausalLMPolicy"]
|
||||
|
||||
|
||||
class DeepseekPolicy(Policy):
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
policy = {}
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
raise NotImplementedError(
|
||||
"Deepseek dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
|
||||
|
||||
if getattr(self.shard_config, "ep_group", None) is not None:
|
||||
# expert parallel
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp",
|
||||
target_module=EPDeepseekMoE,
|
||||
kwargs={"ep_group": self.shard_config.ep_group},
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key="DeepseekDecoderLayer",
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="input_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="post_attention_layernorm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key="DeepseekDecoderLayer",
|
||||
)
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="norm",
|
||||
target_module=FusedRMSNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key="DeepseekModel",
|
||||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
warnings.warn(
|
||||
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
|
||||
)
|
||||
self.shard_config.enable_flash_attention = False
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if self.pipeline_stage_manager:
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == "DeepseekModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(
|
||||
description=method_replacement, policy=policy, target_key=model_cls
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
assert self.pipeline_stage_manager is not None
|
||||
|
||||
if self.model.__class__.__name__ == "DeepseekModel":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
|
||||
if stage_manager.is_first_stage():
|
||||
held_layers.append(module.embed_tokens)
|
||||
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
|
||||
held_layers.extend(module.layers[start_idx:end_idx])
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(module.norm)
|
||||
|
||||
return held_layers
|
||||
|
||||
|
||||
class DeepseekModelPolicy(DeepseekPolicy):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls="DeepseekModel",
|
||||
new_forward=DeepseekPipelineForwards.deepseek_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
held_layers = super().get_held_layers()
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
"""No shared params in llama model"""
|
||||
return []
|
||||
|
||||
|
||||
class DeepseekForCausalLMPolicy(DeepseekPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
# TODO: assign pg mesh from plugin to all modules
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
"DeepseekForCausalLM": ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(
|
||||
model_cls="DeepseekForCausalLM",
|
||||
new_forward=DeepseekPipelineForwards.deepseek_for_causal_lm_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
held_layers = super().get_held_layers()
|
||||
if stage_manager.is_last_stage():
|
||||
held_layers.append(self.model.lm_head)
|
||||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
deepseek_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(deepseek_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: deepseek_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
return []
|
|
@ -192,16 +192,16 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
|
|||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
llama_model = self.model.model
|
||||
mixtral_model = self.model.model
|
||||
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
|
||||
if (
|
||||
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
id(mixtral_model.embed_tokens.weight) == id(self.model.lm_head.weight)
|
||||
and self.pipeline_stage_manager.num_stages > 1
|
||||
):
|
||||
# tie weights
|
||||
return [
|
||||
{
|
||||
0: llama_model.embed_tokens.weight,
|
||||
0: mixtral_model.embed_tokens.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
|
||||
}
|
||||
]
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.testing import assert_close
|
||||
from transformers import AutoConfig, AutoModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.shardformer.modeling.deepseek import EPDeepseekMoE
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
hidden_size = 8
|
||||
top_k = 2
|
||||
|
||||
|
||||
def check_deepseek_moe_layer():
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
precision="bf16",
|
||||
tp_size=1,
|
||||
pp_size=1,
|
||||
ep_size=dist.get_world_size(),
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
"deepseek-ai/deepseek-moe-16b-base",
|
||||
num_hidden_layers=1,
|
||||
n_routed_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
first_k_dense_replace=0,
|
||||
num_attention_heads=2,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
# get the moe layer in auto model
|
||||
orig_model = AutoModel.from_config(config, trust_remote_code=True).layers[0].mlp.cuda()
|
||||
x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
|
||||
orig_output = orig_model(x)
|
||||
model = deepcopy(orig_model)
|
||||
model = EPDeepseekMoE.from_native_module(model, ep_group=plugin.ep_group)
|
||||
ep_output = model(x)
|
||||
assert_close(orig_output, ep_output)
|
||||
orig_loss = orig_output.mean()
|
||||
orig_loss.backward()
|
||||
ep_loss = ep_output.mean()
|
||||
ep_loss.backward()
|
||||
assert_close(orig_loss, ep_loss)
|
||||
name_to_p = {n: p for n, p in orig_model.named_parameters()}
|
||||
for n, ep_p in model.named_parameters():
|
||||
p = name_to_p[n]
|
||||
if ep_p.grad is not None:
|
||||
assert_close(p.grad, ep_p.grad)
|
||||
|
||||
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch(rank, world_size, "localhost", port)
|
||||
check_deepseek_moe_layer()
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("world_size", [2, 4])
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
def test_deepseek_moe_layer(world_size: int):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_deepseek_moe_layer(2)
|
|
@ -15,6 +15,7 @@ from colossalai.booster import Booster
|
|||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.checkpoint_io import MoECheckpointIO
|
||||
from colossalai.tensor.moe_tensor.api import is_moe_tensor
|
||||
from colossalai.testing import parameterize, spawn
|
||||
from colossalai.testing.utils import spawn
|
||||
|
||||
tokens, n_experts = 7, 4
|
||||
|
@ -77,7 +78,23 @@ def check_optimizer_snapshot_equal(snapshot1, snapshot2, param2name, moe_dp_grou
|
|||
raise AssertionError(f"A total of {count} optim states are not equal")
|
||||
|
||||
|
||||
def check_mixtral_moe_layer():
|
||||
@parameterize(
|
||||
"test_config",
|
||||
[
|
||||
[
|
||||
MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
),
|
||||
MixtralForCausalLM,
|
||||
],
|
||||
],
|
||||
)
|
||||
def check_moe_checkpoint(test_config):
|
||||
context = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
|
||||
with context as f:
|
||||
torch.cuda.set_device(dist.get_rank())
|
||||
|
@ -87,17 +104,11 @@ def check_mixtral_moe_layer():
|
|||
broadcast_objects = [None]
|
||||
dist.broadcast_object_list(broadcast_objects, src=0)
|
||||
|
||||
config = MixtralConfig(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=hidden_size * 2,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=top_k,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
config = test_config[0]
|
||||
model_cls = test_config[1]
|
||||
torch.manual_seed(0)
|
||||
input_ids = torch.randint(0, 100, (2, tokens)).cuda()
|
||||
orig_model = MixtralForCausalLM(config).cuda()
|
||||
orig_model = model_cls(config).cuda()
|
||||
model = deepcopy(orig_model)
|
||||
optimizer = Adam(model.parameters(), lr=1e-3)
|
||||
plugin = MoeHybridParallelPlugin(
|
||||
|
@ -120,7 +131,6 @@ def check_mixtral_moe_layer():
|
|||
lambda outputs, inputs: outputs.loss,
|
||||
optimizer,
|
||||
)
|
||||
|
||||
tmpdirname = broadcast_objects[0]
|
||||
model_dir = os.path.join(tmpdirname, "mixtral_model")
|
||||
hf_model_dir = os.path.join(tmpdirname, "mixtral_hf_model")
|
||||
|
@ -129,13 +139,13 @@ def check_mixtral_moe_layer():
|
|||
booster.save_model(model, model_dir, shard=True)
|
||||
dist.barrier()
|
||||
if dist.get_rank() == 0:
|
||||
saved_model = MixtralForCausalLM.from_pretrained(model_dir).cuda()
|
||||
saved_model = model_cls.from_pretrained(model_dir).cuda()
|
||||
check_model_equal(orig_model, saved_model)
|
||||
# check_model_equal(model, saved_model)
|
||||
saved_model.save_pretrained(hf_model_dir)
|
||||
dist.barrier()
|
||||
# check load model
|
||||
new_model = MixtralForCausalLM(config).cuda()
|
||||
new_model = model_cls(config).cuda()
|
||||
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
|
||||
new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
|
||||
booster.load_model(new_model, hf_model_dir)
|
||||
|
@ -163,7 +173,7 @@ def check_mixtral_moe_layer():
|
|||
|
||||
def run_dist(rank: int, world_size: int, port: int):
|
||||
colossalai.launch(rank, world_size, "localhost", port)
|
||||
check_mixtral_moe_layer()
|
||||
check_moe_checkpoint()
|
||||
|
||||
|
||||
# Test EP + ZeRO + PP
|
||||
|
|
Loading…
Reference in New Issue