Add compatibility code for the old version

pull/182/head
Wenwen Qu 2023-09-26 11:51:34 +08:00
parent 85f4d4af58
commit 3c8fee01b2
13 changed files with 962 additions and 188 deletions

configs/7B_MoE8_sft.py (new file, +170 lines)
View File

@ -0,0 +1,170 @@
JOB_NAME = "7b_moe_train"
DO_ALERT = False
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
load_ckpt_folder="local:llm_ckpts/",
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded, supporting: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded; currently only the 'normal' type is supported.
# (see the hypothetical example after this dict)
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
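# Hypothetical example of the 'load_ckpt_info' guide above (not part of this config): resume every
# state from a previous training checkpoint instead of loading model weights only:
# load_ckpt_info = dict(path=LOAD_CKPT_FOLDER, content=("all",), ckpt_type="internlm")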
TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro-batches contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, which disables evaluation
valid_every=50,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
rampup_batch_size="",
# Datasets with fewer than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=10,
diag_outlier_ratio=1.1,
)
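# Worked example of the settings above (DP is the data-parallel world size, not set here):
# packed_length = micro_bsz * SEQ_LEN = 2 * 2048 = 4096 tokens per micro-batch,
# so one gradient update consumes micro_num * packed_length * DP = 4 * 4096 * DP tokens.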
grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of steps to increase loss scale when no overflow occurs
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)
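# Worked illustration of the dynamic loss scaling configured above (standard behaviour, stated for clarity):
# starting from initial_scale=2**16, the scale is multiplied by growth_factor=2 after every
# growth_interval=1000 consecutive overflow-free steps (capped at max_scale=2**24), and multiplied
# by backoff_factor=0.5 after hysteresis=2 overflows, never dropping below min_scale=1.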
hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
overlap_sync_grad=True,
overlap_sync_param=True,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)
loss = dict(
label_smoothing=0,
)
adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)
lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)
beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)
model = dict(
checkpoint=False,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=8,
moe_use_residual=False,
moe_gate_k=2,
)
"""
zero1 parallel:
1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. if zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of the dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using the interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=-1,
tensor=8,
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=False,
)
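# Worked example (assumed world size of 32 GPUs with the settings above): tensor=8 and a pipeline
# size of 1 leave a data-parallel group of 32 / (8 * 1) = 4 ranks; zero1=-1 (<= 0) then makes the
# ZeRO-1 group equal to that data-parallel group, so optimizer states are sharded across those 4 ranks.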
cudnn_deterministic = False
cudnn_benchmark = False
monitor = dict(
# feishu alert configs
alert=dict(
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
),
)
model_type = "INTERNLM_MoE"
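# Note: the compatibility branches added elsewhere in this commit dispatch on this value; "INTERNLM_MoE"
# selects the build_model_with_moe_cfg builder registered in modeling_moe.py. A hypothetical dense
# configuration (shown only for contrast) would instead set:
# model_type = "INTERNLM"  # schedulers then return (output, label, loss) without a moe_loss entry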

View File

@ -107,7 +107,10 @@ class NonPipelineScheduler(BaseScheduler):
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
# moe_losses contains the loss of each layer
output, moe_losses = self._call_engine(engine, data)
if gpc.config.get("model_type") == "INTERNLM":
output = self._call_engine(engine, data)
if gpc.config.get("model_type") == "INTERNLM_MoE":
output, moe_losses = self._call_engine(engine, data)
self._call_hooks("after_forward", output)
self._call_hooks("post_helper_func", output, label)
@ -116,7 +119,11 @@ class NonPipelineScheduler(BaseScheduler):
self._call_hooks("before_criterion", output, label)
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
if gpc.config.get("model_type") == "INTERNLM_MoE"
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss
@ -199,4 +206,8 @@ class NonPipelineScheduler(BaseScheduler):
if not return_output_label:
outputs, labels = None, None
return outputs, labels, loss, moe_loss
# Compatible with old code
if gpc.config.get("model_type") == "INTERNLM":
return outputs, labels, loss
if gpc.config.get("model_type") == "INTERNLM_MoE":
return outputs, labels, loss, moe_loss
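Note (not part of the diff): a minimal sketch of how a call site could consume either return shape without branching on model_type; scheduler, engine and data_iter are placeholder names for an already-constructed scheduler, engine and data iterator.
results = scheduler.forward_backward_step(engine, data_iter, forward_only=False, return_loss=True)
outputs, labels, loss = results[:3]
moe_loss = results[3] if len(results) == 4 else None  # only the MoE path returns a fourth element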

View File

@ -276,7 +276,10 @@ class PipelineScheduler(BaseScheduler):
self._call_hooks("before_forward", data)
# moe_losses contains the loss of each layer in current stage
output_obj, moe_losses = self._call_engine(engine.model, data)
if gpc.config.get("model_type") == "INTERNLM":
output_obj = self._call_engine(engine.model, data)
if gpc.config.get("model_type") == "INTERNLM_MoE":
output_obj, moe_losses = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)
if gpc.is_last_rank(ParallelMode.PIPELINE):
@ -292,7 +295,11 @@ class PipelineScheduler(BaseScheduler):
accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
if gpc.config.get("model_type") == "INTERNLM_MoE"
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
@ -658,9 +665,19 @@ class PipelineScheduler(BaseScheduler):
self.load_batch(engine, data_iter)
if forward_only:
return self._forward_only_step(engine, return_loss, return_output_label)
output, label, accum_loss, accum_moe_loss = self._forward_only_step(
engine, return_loss, return_output_label
)
else:
return self._forward_backward_step(engine, return_loss, return_output_label)
output, label, accum_loss, accum_moe_loss = self._forward_backward_step(
engine, return_loss, return_output_label
)
# Compatible with old code
if gpc.config.get("model_type") == "INTERNLM":
return output, label, accum_loss
if gpc.config.get("model_type") == "INTERNLM_MoE":
return output, label, accum_loss, accum_moe_loss
class InterleavedPipelineScheduler(PipelineScheduler):
@ -799,7 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
if gpc.config.get("model_type") == "INTERNLM":
output_obj = self._call_engine(engine.model[chunk_id], data)
if gpc.config.get("model_type") == "INTERNLM_MoE":
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@ -819,7 +839,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
if gpc.config.get("model_type") == "INTERNLM_MoE"
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
moe_loss /= self.num_microbatches
if self._accum_moe_loss is not None:
@ -1354,4 +1378,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._clear_state()
return output, label, accum_loss, accum_moe_loss
# Compatible with old code
if gpc.config.get("model_type") == "INTERNLM":
return output, label, accum_loss
if gpc.config.get("model_type") == "INTERNLM_MoE":
return output, label, accum_loss, accum_moe_loss

View File

@ -205,5 +205,4 @@ class Trainer:
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss) when model_type is "INTERNLM", or (output, label, loss, moe_loss) when model_type is "INTERNLM_MoE".
"""
output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
return output, label, loss, moe_loss
return self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)

View File

@ -253,8 +253,14 @@ def args_sanity_check():
# process the model config
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
if "num_experts" not in model:
model._add_item("num_experts", 1)
if gpc.config.get("model_type") == "INTERNLM_MoE":
if "num_experts" not in model:
model._add_item("num_experts", 1)
if "moe_use_residual" not in model:
model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2)
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:

View File

@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
from .metrics import AccPerplex
from .modeling_internlm import build_model_with_cfg
from .modeling_moe import build_model_with_moe_cfg
from .multi_head_attention import MHA
from .utils import gather_forward_split_backward
@ -18,4 +19,5 @@ __all__ = [
"MHA",
"gather_forward_split_backward",
"build_model_with_cfg",
"build_model_with_moe_cfg",
]

View File

@ -18,7 +18,6 @@ from internlm.model.linear import (
RewardModelLinear,
ScaleColumnParallelLinear,
)
from internlm.model.moe import MoE
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform
@ -51,17 +50,6 @@ class PackedFlashBaseLayer1D(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
def __init__(
@ -84,15 +72,6 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
self.checkpoint = checkpoint
@ -127,77 +106,41 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
for param in self.norm1.parameters():
param.is_norm = True
for param in self.norm2.parameters():
param.is_norm = True
self.num_experts = num_experts
self.moe_gate_k = moe_gate_k
self.moe_capacity_factor = moe_capacity_factor
self.moe_eval_capacity_factor = moe_eval_capacity_factor
self.moe_min_capacity = moe_min_capacity
self.moe_noisy_gate_policy = moe_noisy_gate_policy
self.moe_drop_tokens = moe_drop_tokens
self.moe_use_rts = moe_use_rts
self.moe_use_residual = moe_use_residual
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
if num_experts <= 1: # dense, not MoE
if use_swiglu:
self.mlp = FeedForward(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
else:
self.mlp = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
activation="gelu_approx",
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=gpc.config.model.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
dtype=dtype,
)
for _, param in self.mlp.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
else:
# replace mlp by MoE module. The expert in MoE is a FeedForward module.
self.mlp = MoE(
hidden_size=hidden_size,
num_experts=num_experts,
ep_size=ep_size,
k=moe_gate_k,
capacity_factor=moe_capacity_factor,
eval_capacity_factor=moe_eval_capacity_factor,
min_capacity=moe_min_capacity,
noisy_gate_policy=moe_noisy_gate_policy,
drop_tokens=moe_drop_tokens,
use_rts=moe_use_rts,
use_residual=moe_use_residual,
if use_swiglu:
self.mlp = FeedForward(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
for _, param in self.mlp.moe_layer.experts.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
else:
self.mlp = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
activation="gelu_approx",
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=gpc.config.parallel.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
dtype=dtype,
)
for _, param in self.mlp.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
self.return_residual = False
self.reset_parameters() # TODO: check this should be changed when moe is added
self.reset_parameters()
def reset_parameters(self):
with torch.no_grad():
@ -229,7 +172,7 @@ class PackedFlashBaseLayer1D(nn.Module):
if self.checkpoint and self.training:
return activation_checkpoint(
self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
) # TODO: check whether this will be affected by moe
)
else:
return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
@ -279,14 +222,9 @@ class PackedFlashBaseLayer1D(nn.Module):
if self.residual_in_fp32:
residual = residual.to(torch.float32)
# MLP.
moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
if self.num_experts <= 1: # dense mlp output
hidden_states = self.mlp(hidden_states)
else: # MoE output
hidden_states, moe_loss, _ = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states)
return hidden_states + residual, moe_loss
return hidden_states + residual
class PackedFlashInternLm1D(nn.Module):
@ -316,17 +254,7 @@ class PackedFlashInternLm1D(nn.Module):
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
def __init__(
@ -357,15 +285,6 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: bool = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
@ -415,15 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
for lid in range(num_layers)
]
@ -450,8 +360,7 @@ class PackedFlashInternLm1D(nn.Module):
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
# attention_mask: compute attention on the places where the value is 1
# old condition may fail when use shared embedding
if gpc.is_pipeline_first_stage():
if hasattr(self, "embedding"):
hidden_states = self.embedding(input_ids)
if self.embed_grad_scale != 1:
hidden_states = (
@ -472,16 +381,14 @@ class PackedFlashInternLm1D(nn.Module):
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
moe_losses = []
for _, block in enumerate(self.blocks):
hidden_states, mos_loss = block(
hidden_states = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=inference_params,
max_seqlen=max_seqlen,
)
moe_losses.append(mos_loss)
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
@ -490,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module):
if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
return hidden_states, moe_losses
return hidden_states
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
@ -558,15 +465,6 @@ def build_model_with_cfg(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
"""
Build model with config.
@ -597,17 +495,7 @@ def build_model_with_cfg(
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
cfg = dict(
@ -632,15 +520,6 @@ def build_model_with_cfg(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)

View File

@ -0,0 +1,646 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from typing import Optional
import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
FeedForward,
RewardModelLinear,
ScaleColumnParallelLinear,
)
from internlm.model.moe import MoE
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER
MODEL_TYPE = "INTERNLM_MoE"
logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()
class PackedFlashBaseLayer1D(nn.Module):
"""
1D Packed Flash Base Layer.
Args:
hidden_size (int): The hidden size of model. 768 by default.
num_attention_heads (int): The number of attention heads. 12 by default.
mlp_ratio (int): The ratio of MLP layers. 4 by default.
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
dtype (torch.dtype): Type of data. torch.float by default.
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
layer_idx (int): The index of current layer. 0 by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
def __init__(
self,
hidden_size: int = 768,
num_attention_heads: int = 12,
mlp_ratio: int = 4,
attn_drop_rate: float = 0,
drop_rate: float = 0.0,
max_position_embeddings: int = 2048,
dtype: torch.dtype = torch.float,
layer_norm_epsilon: float = 1e-6,
checkpoint: bool = False,
layer_idx: int = 0,
use_dynamic_ntk_rope: bool = False,
residual_in_fp32: bool = False,
device: Optional[torch.device] = None,
norm_type: str = "rmsnorm",
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
self.checkpoint = checkpoint
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
self.layer_idx = layer_idx
self.use_flash_attn = use_flash_attn
head_dim = hidden_size // num_attention_heads
self.mixer = MHA(
embed_dim=hidden_size,
num_heads=num_attention_heads,
process_group=gpc.get_group(ParallelMode.TENSOR),
dropout=attn_drop_rate,
max_position_embeddings=max_position_embeddings,
softmax_scale=1 / math.sqrt(head_dim),
causal=True,
layer_idx=layer_idx,
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
rotary_emb_dim=head_dim,
rotary_emb_scale_base=0,
use_flash_attn=use_flash_attn,
device=device,
dtype=dtype,
)
self.dropout1 = nn.Dropout(drop_rate)
if norm_type == "rmsnorm":
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
else:
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
for param in self.norm1.parameters():
param.is_norm = True
for param in self.norm2.parameters():
param.is_norm = True
self.num_experts = num_experts
self.moe_gate_k = moe_gate_k
self.moe_capacity_factor = moe_capacity_factor
self.moe_eval_capacity_factor = moe_eval_capacity_factor
self.moe_min_capacity = moe_min_capacity
self.moe_noisy_gate_policy = moe_noisy_gate_policy
self.moe_drop_tokens = moe_drop_tokens
self.moe_use_rts = moe_use_rts
self.moe_use_residual = moe_use_residual
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
if num_experts <= 1: # dense, not MoE
if use_swiglu:
self.mlp = FeedForward(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
else:
self.mlp = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
activation="gelu_approx",
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=gpc.config.model.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
dtype=dtype,
)
for _, param in self.mlp.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
else:
# replace mlp by MoE module. The expert in MoE is a FeedForward module.
self.mlp = MoE(
hidden_size=hidden_size,
num_experts=num_experts,
ep_size=ep_size,
k=moe_gate_k,
capacity_factor=moe_capacity_factor,
eval_capacity_factor=moe_eval_capacity_factor,
min_capacity=moe_min_capacity,
noisy_gate_policy=moe_noisy_gate_policy,
drop_tokens=moe_drop_tokens,
use_rts=moe_use_rts,
use_residual=moe_use_residual,
device=device,
dtype=dtype,
)
for _, param in self.mlp.moe_layer.experts.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
self.return_residual = False
self.reset_parameters() # TODO: check whether this should be changed when MoE is added
def reset_parameters(self):
with torch.no_grad():
for name, param in self.mixer.named_parameters():
if param.ndim == 1:
param.data.zero_()
elif "Wqkv" in name:
normal_(std=0.006)(param.data)
elif self.use_scaled_init:
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
else:
normal_(std=0.0015)(param.data)
for name, param in self.mlp.named_parameters():
if param.ndim == 1 and "bias" in name:
param.data.zero_()
elif self.use_swiglu:
if self.use_scaled_init and "w2" in name:
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
else:
normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
else:
if self.use_scaled_init and "fc1" not in name:
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
else:
normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
if self.checkpoint and self.training:
return activation_checkpoint(
self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
) # TODO: check whether this will be affected by moe
else:
return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Attn/MLP(LN(residual))
cu_seqlens: 1d LongTensor of cumulative sequence lengths; its length is the number of packed sequences + 1
indexes: has the same length as hidden_states and indicates the position of each token within its sequence
"""
mixer_kwargs = {
"cu_seqlens": cu_seqlens,
"max_seqlen": max_seqlen,
"indexes": indexes,
"inference_params": inference_params,
}
def _dropout_and_norm_attn(_hidden_states):
_dropped = self.dropout1(_hidden_states)
_residual = _dropped
_hidden_states = self.norm1(_residual.float())
return _residual, _hidden_states
if self.dropout_selective_checkpoint:
residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
else:
residual, hidden_states = _dropout_and_norm_attn(hidden_states)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
def _dropout_and_norm_ffn(_residual, _hidden_states):
_dropped = self.dropout2(_hidden_states)
_residual = (_dropped + _residual) if _residual is not None else _dropped
_hidden_states = self.norm2(_residual.float())
return _residual, _hidden_states
if self.dropout_selective_checkpoint:
residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
else:
residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
# MLP.
moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
if self.num_experts <= 1: # dense mlp output
hidden_states = self.mlp(hidden_states)
else: # MoE output
hidden_states, moe_loss, _ = self.mlp(hidden_states)
return hidden_states + residual, moe_loss
class PackedFlashInternLm1D(nn.Module):
"""
1D Packed Flash InternLm.
Args:
num_layers (int): The number of layers. 12 by default.
hidden_size (int): The size of the hidden state. 768 by default.
num_attention_heads (int): The number of attention heads. 12 by default.
vocab_size (int): The size of vocabulary. 50304 by default.
mlp_ratio (int): The ratio of MLP layers. 4 by default.
attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
dtype (torch.dtype): The type of data. torch.float by default.
checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
of layers. 0.0 by default.
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
first (bool): Whether input embedding layer or not. False by default.
last (bool): Whether output embedding layer or not. False by default.
embed_split_hidden (bool): Split the embedding layer in the hidden-state dimension or the vocabulary dimension.
True by default.
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
max_position_embeddings: int = 2048,
dtype: torch.dtype = torch.float,
checkpoint: float = 0.0,
layer_norm_epsilon: float = 1e-5,
first: bool = False,
last: bool = False,
embed_split_hidden: bool = False,
embed_grad_scale: float = 0.1,
parallel_output: bool = True,
start_layer_idx: int = 0,
use_dynamic_ntk_rope: bool = False,
device: Optional[torch.device] = None,
residual_in_fp32: bool = False,
norm_type: str = "rmsnorm",
is_reward: bool = False,
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
checkpoint_layer_num = int(num_layers * checkpoint)
if is_reward:
head_cls = RewardModelLinear
else:
head_cls = ScaleColumnParallelLinear
if first:
if embed_split_hidden:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
else:
self.embedding = ParallelGPT2Embeddings(
embed_dim=hidden_size,
vocab_size=vocab_size,
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
for _, param in self.embedding.named_parameters():
normal_(std=0.0052)(param)
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.embed_grad_scale = embed_grad_scale
self.blocks = nn.ModuleList(
[
PackedFlashBaseLayer1D(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
mlp_ratio=mlp_ratio,
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
max_position_embeddings=max_position_embeddings,
dtype=dtype,
layer_norm_epsilon=layer_norm_epsilon,
checkpoint=lid < checkpoint_layer_num,
layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
residual_in_fp32=residual_in_fp32,
device=device,
norm_type=norm_type,
dropout_selective_checkpoint=dropout_selective_checkpoint,
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
for lid in range(num_layers)
]
)
if last:
if norm_type == "rmsnorm":
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
else:
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.head = head_cls(
in_features=hidden_size,
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
weight_scale=embed_grad_scale,
)
for _, param in self.head.named_parameters():
normal_(std=0.0052)(param)
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
# attention_mask: compute attention on the places where the value is 1
# the old condition may fail when using a shared embedding
if gpc.is_pipeline_first_stage():
hidden_states = self.embedding(input_ids)
if self.embed_grad_scale != 1:
hidden_states = (
self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
)
if isinstance(cu_seqlens, list):
assert len(cu_seqlens) == 1
cu_seqlens = cu_seqlens[0].to(hidden_states.device)
if cu_seqlens is not None:
cu_seqlens = cu_seqlens.squeeze(0)
hidden_states = hidden_states.squeeze(0)  # If cu_seqlens is passed in, it indicates a packed state,
# so the batch dimension with a size of 1 should be directly squeezed off.
if indexes is not None:
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
moe_losses = []
for _, block in enumerate(self.blocks):
hidden_states, moe_loss = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=inference_params,
max_seqlen=max_seqlen,
)
moe_losses.append(moe_loss)
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
if hasattr(self, "head"):
hidden_states = self.head(hidden_states)
if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
return hidden_states, moe_losses
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
"""
build generic model 1d
Args:
num_layers (int): The number of layers.
num_chunks (int): The number of partitions in pipeline parallel.
device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
"""
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
parts = all_parts[pipeline_rank]
if gpc.is_rank_for_log():
logger.info(f"The layer sharding is {all_parts}.")
models = []
for start, end in parts:
kwargs["num_layers"] = end - start
kwargs["first"] = start == 0
# Mark this chunk as the last one only when it ends at num_layers and the final partition is non-empty.
kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
kwargs["device"] = device
kwargs["start_layer_idx"] = start
chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)
models.append(chunk)
torch.distributed.barrier()
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
return model
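# Illustration (assumed values, not from the diff): with num_layers=32, pipeline_size=4 and
# num_chunks=1, a uniform partition gives all_parts = [[(0, 8)], [(8, 16)], [(16, 24)], [(24, 32)]],
# so each pipeline rank builds one 8-layer chunk; the rank holding (0, 8) also builds the embedding
# (first=True) and the rank holding (24, 32) builds the final norm and head (last=True).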
@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
def build_model_with_moe_cfg(
num_chunks=1,
checkpoint=0.0,
dtype=torch.float,
embed_split_hidden=False,
num_layers=48,
hidden_size=2048,
vocab_size=50304,
embed_grad_scale=1,
parallel_output=True,
num_attention_heads=32,
max_position_embeddings=2048,
mlp_ratio=4.0,
residual_in_fp32=False,
use_dynamic_ntk_rope=False,
norm_type="rmsnorm",
drop_rate=0,
attn_drop_rate=0,
apply_post_layer_norm=False, # pylint: disable=W0613
layer_norm_epsilon=1e-5,
is_reward=False,
dropout_selective_checkpoint=True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
"""
Build model with config.
Args:
num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
checkpoint (float): The proportion of layers to apply activation checkpointing to. 0.0 by default.
dtype (torch.dtype): The type of data. torch.float by default.
embed_split_hidden (bool): Split the embedding layer in the hidden-state dimension or the vocabulary dimension.
False by default.
num_layers (int): The number of layers. 48 by default.
hidden_size (int): The size of hidden state. 2048 by default.
vocab_size (int): The size of vocabulary. 50304 by default.
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
num_attention_heads (int): The number of attention heads. 32 by default.
mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used for now
because it would require inconsistent data types to be passed between pipeline stages,
which needs significant modifications to internlm.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
drop_rate (float): The dropout rate of input hidden state. 0 by default.
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
is_reward (bool): Whether to use reward model. False by default.
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
"""
cfg = dict(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
vocab_size=vocab_size,
embed_grad_scale=embed_grad_scale,
parallel_output=parallel_output,
mlp_ratio=mlp_ratio,
residual_in_fp32=residual_in_fp32,
max_position_embeddings=max_position_embeddings,
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
norm_type=norm_type,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
layer_norm_epsilon=layer_norm_epsilon,
is_reward=is_reward,
dropout_selective_checkpoint=dropout_selective_checkpoint,
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
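Note (not part of the diff): a minimal sketch of calling the new builder directly with values mirroring configs/7B_MoE8_sft.py. It assumes the parallel context (gpc) and process groups have already been initialized by the normal InternLM launch path, since the builder reads them internally; in real training the MODEL_INITIALIZER registry dispatches to this function via model_type instead.
import torch
from internlm.model import build_model_with_moe_cfg

# Hypothetical direct call, shown only for orientation.
model = build_model_with_moe_cfg(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    vocab_size=103168,
    mlp_ratio=8 / 3,
    dtype=torch.bfloat16,
    num_experts=8,
    moe_gate_k=2,
)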

View File

@ -111,7 +111,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
adam_cfg = gpc.config.adam
# split the moe parameters into different groups
if gpc.config.model.num_experts > 1:
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
params = create_param_groups(model, adam_cfg.weight_decay)
else:
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
@ -435,8 +435,7 @@ def record_current_batch_training_metrics(
infos = {
"tflops": tflops,
"step": batch_count,
"loss": loss.item() - moe_loss.item(),
"moe_loss": moe_loss.item(),
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
"tgs (tokens/gpu/second)": tgs_origin,
"tgs/last_tgs_1": last_tgs_1,
"tgs/tgs_all": tgs_all,
@ -448,6 +447,8 @@ def record_current_batch_training_metrics(
"loss_scale": scaler,
"grad_norm": grad_norm,
}
if moe_loss is not None:
infos["moe_loss"] = moe_loss.item()
infos["micro_num"] = len(batch[1])
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
@ -481,13 +482,14 @@ def record_current_batch_training_metrics(
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"loss": loss.item() - moe_loss.item(),
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
"flops": tflops,
"tgs": last_tgs_1,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
}
panel_metrics["moe_loss"] = moe_loss.item()
for norm_key, norm_value in grad_norm.items():
panel_metrics[norm_key] = norm_value

View File

@ -97,6 +97,7 @@ def evaluate_on_val_dls(
disable=not verbose,
leave=False,
):
moe_loss = None
with torch.inference_mode():
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
@ -112,9 +113,15 @@ def evaluate_on_val_dls(
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
# Compatible with old code
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
elif gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
@ -126,11 +133,16 @@ def evaluate_on_val_dls(
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
elif gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item() - moe_loss.item()
val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item()
assert val_idx != -1
dist.barrier()

View File

@ -186,11 +186,20 @@ def train(
# do forward and backward
timer("fwd-bwd").start()
_, _, loss, moe_loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
# Compatible with old code
moe_loss = None
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False
)
elif gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False
)
if gpc.is_rank_for_log():
assert loss is not None and not math.isnan(loss.item())
global cur_loss_list
cur_loss_list.append(loss.item() - moe_loss.item())
cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item()))
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)

View File

@ -59,6 +59,7 @@ init_config = Config(
def init_naive_model():
# let MODEL_INITIALIZER to work
import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import
import internlm.model.modeling_moe # noqa # pylint: disable=unused-import
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.registry import MODEL_INITIALIZER

View File

@ -219,12 +219,21 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
_, _, loss, moe_loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
moe_loss = None
if gpc.config.get("model_type") == "INTERNLM":
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
if gpc.config.get("model_type") == "INTERNLM_MoE":
_, _, loss, moe_loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
)
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)