mirror of https://github.com/InternLM/InternLM
add compatible code for old version
parent 85f4d4af58
commit 3c8fee01b2

@@ -0,0 +1,170 @@
JOB_NAME = "7b_moe_train"
|
||||
DO_ALERT = False
|
||||
|
||||
SEQ_LEN = 2048
|
||||
HIDDEN_SIZE = 4096
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
NUM_LAYER = 32
|
||||
VOCAB_SIZE = 103168
|
||||
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||
# Ckpt folder format:
|
||||
# fs: 'local:/mnt/nfs/XXX'
|
||||
SAVE_CKPT_FOLDER = "local:llm_ckpts"
|
||||
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
|
||||
# boto3 Ckpt folder format:
|
||||
# import os
|
||||
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
|
||||
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
|
||||
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
|
||||
CHECKPOINT_EVERY = 50
|
||||
ckpt = dict(
|
||||
enable_save_ckpt=False, # enable ckpt save.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
|
||||
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
|
||||
load_ckpt_folder="local:llm_ckpts/",
|
||||
# 'load_ckpt_info' setting guide:
|
||||
# 1. the 'path' indicates the ckpt path,
|
||||
# 2. the 'content' means what states will be loaded, supported: "model", "sampler", "optimizer", "scheduler", "all"
|
||||
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, currently only the 'normal' type is supported.
|
||||
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
|
||||
checkpoint_every=CHECKPOINT_EVERY,
|
||||
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
|
||||
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
|
||||
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||
)
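As a side note on the 'load_ckpt_info' guide above: resuming full training state rather than loading model weights only would, by those comments, look roughly like the snippet below. This is an illustrative sketch; whether the ("all",) content and this ckpt_type combination is accepted depends on the InternLM version in use.

```python
# Illustrative only, following the setting-guide comments above: resume model,
# sampler, optimizer and scheduler state from the step-49 checkpoint.
resume_ckpt_example = dict(
    load_ckpt_info=dict(path=LOAD_CKPT_FOLDER, content=("all",), ckpt_type="internlm"),
)
```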
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
VALID_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
# micro_num means the number of micro_batches contained in one gradient update
|
||||
micro_num=4,
|
||||
# packed_length = micro_bsz * SEQ_LEN
|
||||
micro_bsz=2,
|
||||
# defaults to the value of micro_num
|
||||
valid_micro_num=4,
|
||||
# defaults to 0, which means evaluation is disabled
|
||||
valid_every=50,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
skip_batches="",
|
||||
rampup_batch_size="",
|
||||
# Datasets with fewer than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
# valid_folder=VALID_FOLDER,
|
||||
empty_cache_and_diag_interval=10,
|
||||
diag_outlier_ratio=1.1,
|
||||
)
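A quick sanity check of the comments above (illustrative arithmetic, not InternLM code): with these values each packed micro batch holds micro_bsz * SEQ_LEN tokens, and one gradient update on a single data-parallel rank consumes micro_num such micro batches.

```python
# Illustrative arithmetic based on the config comments above.
SEQ_LEN, micro_bsz, micro_num = 2048, 2, 4
packed_length = micro_bsz * SEQ_LEN                       # 4096 tokens per packed micro batch
tokens_per_step_per_dp_rank = micro_num * packed_length   # 16384 tokens per optimizer step
```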
|
||||
|
||||
grad_scaler = dict(
|
||||
fp16=dict(
|
||||
# the initial loss scale, defaults to 2**16
|
||||
initial_scale=2**16,
|
||||
# the minimum loss scale, defaults to None
|
||||
min_scale=1,
|
||||
# the number of steps to increase loss scale when no overflow occurs
|
||||
growth_interval=1000,
|
||||
),
|
||||
# the multiplication factor for increasing loss scale, defaults to 2
|
||||
growth_factor=2,
|
||||
# the multiplication factor for decreasing loss scale, defaults to 0.5
|
||||
backoff_factor=0.5,
|
||||
# the maximum loss scale, defaults to None
|
||||
max_scale=2**24,
|
||||
# the number of overflows before decreasing loss scale, defaults to 2
|
||||
hysteresis=2,
|
||||
)
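The fields above follow the usual dynamic loss scaling recipe: back off by backoff_factor once hysteresis overflows have been seen, and grow by growth_factor after growth_interval consecutive clean steps, clamped to [min_scale, max_scale]. The sketch below is a generic illustration of that recipe, not InternLM's actual scaler.

```python
# Generic dynamic-loss-scaling sketch; InternLM's real scaler may differ in details.
class ToyLossScaler:
    def __init__(self, initial_scale=2**16, min_scale=1, max_scale=2**24,
                 growth_factor=2, backoff_factor=0.5, growth_interval=1000, hysteresis=2):
        self.scale = initial_scale
        self.min_scale, self.max_scale = min_scale, max_scale
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval, self.hysteresis = growth_interval, hysteresis
        self._good_steps = 0
        self._overflows_left = hysteresis

    def update(self, found_overflow: bool):
        if found_overflow:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:          # tolerated overflows exhausted
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:   # long clean streak
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._good_steps = 0
```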
|
||||
|
||||
hybrid_zero_optimizer = dict(
|
||||
# Enable low_level_optimizer overlap_communication
|
||||
overlap_sync_grad=True,
|
||||
overlap_sync_param=True,
|
||||
# bucket size for nccl communication params
|
||||
reduce_bucket_size=512 * 1024 * 1024,
|
||||
# grad clipping
|
||||
clip_grad_norm=1.0,
|
||||
)
|
||||
|
||||
loss = dict(
|
||||
label_smoothing=0,
|
||||
)
|
||||
|
||||
adam = dict(
|
||||
lr=1e-4,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.95,
|
||||
adam_beta2_c=0,
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
lr_scheduler = dict(
|
||||
total_steps=data["total_steps"],
|
||||
init_steps=0, # optimizer_warmup_step
|
||||
warmup_ratio=0.01,
|
||||
eta_min=1e-5,
|
||||
last_epoch=-1,
|
||||
)
|
||||
|
||||
beta2_scheduler = dict(
|
||||
init_beta2=adam["adam_beta2"],
|
||||
c=adam["adam_beta2_c"],
|
||||
cur_iter=-1,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
checkpoint=False, # The proportion of layers for activation checkpointing; the optional values are True/False/[0-1]
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||
num_experts=8,
|
||||
moe_use_residual=False,
|
||||
moe_gate_k=2,
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
|
||||
so parameters will be divided within the range of dp.
|
||||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||
3. if zero1 > 1 and zero1 <= dp world size, the zero process group is a subset of the dp process group.
|
||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||
pipeline parallel (dict):
|
||||
1. size: int, the size of pipeline parallel.
|
||||
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
zero1=-1,
|
||||
tensor=8,
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
sequence_parallel=False,
|
||||
)
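Following the zero1 rules spelled out in the docstring above, the effective size of the ZeRO process group could be resolved roughly as below. This is an illustrative sketch; the real resolution happens inside InternLM's parallel context.

```python
# Illustrative resolution of the zero1 setting described in the docstring above.
def resolve_zero1_size(zero1: int, dp_world_size: int) -> int:
    if zero1 <= 0:        # shard optimizer states across the whole dp group
        return dp_world_size
    if zero1 == 1:        # ZeRO sharding disabled: every dp rank keeps full states
        return 1
    assert zero1 <= dp_world_size, "zero1 must not exceed the dp world size"
    return zero1          # shard within a subgroup of dp, typically <= 8 (one node)

print(resolve_zero1_size(-1, dp_world_size=8))  # -> 8 with the zero1=-1 setting above
```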
|
||||
|
||||
cudnn_deterministic = False
|
||||
cudnn_benchmark = False
|
||||
|
||||
monitor = dict(
|
||||
# feishu alert configs
|
||||
alert=dict(
|
||||
enable_feishu_alert=DO_ALERT,
|
||||
feishu_alert_address=None, # feishu webhook to send alert message
|
||||
light_monitor_address=None, # light_monitor address to send heartbeat
|
||||
),
|
||||
)
|
||||
|
||||
model_type = "INTERNLM_MoE"
|
|
@@ -107,7 +107,10 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
with conditional_context(torch.no_grad(), enable=forward_only):
|
||||
self._call_hooks("before_forward", data)
|
||||
# moe_losses contains the loss of each layer
|
||||
output, moe_losses = self._call_engine(engine, data)
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
output = self._call_engine(engine, data)
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
output, moe_losses = self._call_engine(engine, data)
|
||||
self._call_hooks("after_forward", output)
|
||||
|
||||
self._call_hooks("post_helper_func", output, label)
|
||||
|
@@ -116,7 +119,11 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
self._call_hooks("before_criterion", output, label)
|
||||
loss = self._call_engine_criterion(engine, output, label)
|
||||
self._call_hooks("after_criterion", loss)
|
||||
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
moe_loss = (
|
||||
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE"
|
||||
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
|
||||
)
|
||||
moe_loss /= scale_loss
|
||||
loss /= scale_loss
|
||||
loss += moe_loss
|
||||
|
@@ -199,4 +206,8 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
if not return_output_label:
|
||||
outputs, labels = None, None
|
||||
|
||||
return outputs, labels, loss, moe_loss
|
||||
# Compatible for old code
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
return outputs, labels, loss
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
return outputs, labels, loss, moe_loss
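Because the scheduler now returns a 3-tuple for "INTERNLM" and a 4-tuple for "INTERNLM_MoE", callers that want a single code path can normalize the result first. The helper below is hypothetical and not part of this diff; it just illustrates the compatibility contract.

```python
# Hypothetical helper (not in the diff): normalize the two return signatures.
def unpack_schedule_result(result):
    if len(result) == 3:                       # old / dense "INTERNLM" signature
        outputs, labels, loss = result
        moe_loss = None
    else:                                      # "INTERNLM_MoE" signature
        outputs, labels, loss, moe_loss = result
    return outputs, labels, loss, moe_loss
```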
|
||||
|
|
|
@@ -276,7 +276,10 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
self._call_hooks("before_forward", data)
|
||||
# moe_losses contains the loss of each layer in current stage
|
||||
output_obj, moe_losses = self._call_engine(engine.model, data)
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
output_obj = self._call_engine(engine.model, data)
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
output_obj, moe_losses = self._call_engine(engine.model, data)
|
||||
self._call_hooks("after_forward", output_obj)
|
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@@ -292,7 +295,11 @@ class PipelineScheduler(BaseScheduler):
|
|||
accum_loss.add_(loss_reduced.detach())
|
||||
output_obj = loss_reduced
|
||||
|
||||
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
moe_loss = (
|
||||
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE"
|
||||
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
|
||||
)
|
||||
moe_loss /= self.num_microbatches
|
||||
accum_moe_loss.add_(moe_loss.detach())
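A small numeric illustration of the accumulation above, with invented values: each micro batch contributes the coefficient-weighted sum of its per-layer auxiliary losses divided by the number of micro batches, so accum_moe_loss ends up as the per-micro-batch average.

```python
# Invented numbers, purely to illustrate the accumulation above.
moe_loss_coeff = 0.01                      # assumed value of gpc.config.loss.moe_loss_coeff
num_microbatches = 4
per_layer_aux = [0.8, 0.7, 0.9, 0.6]       # aux loss from each MoE layer in this stage
per_microbatch = sum(per_layer_aux) * moe_loss_coeff / num_microbatches   # 0.0075
accum_moe_loss = per_microbatch * num_microbatches   # 0.03 if every micro batch matched
```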
|
||||
|
||||
|
@@ -658,9 +665,19 @@ class PipelineScheduler(BaseScheduler):
|
|||
self.load_batch(engine, data_iter)
|
||||
|
||||
if forward_only:
|
||||
return self._forward_only_step(engine, return_loss, return_output_label)
|
||||
output, label, accum_loss, accum_moe_loss = self._forward_only_step(
|
||||
engine, return_loss, return_output_label
|
||||
)
|
||||
else:
|
||||
return self._forward_backward_step(engine, return_loss, return_output_label)
|
||||
output, label, accum_loss, accum_moe_loss = self._forward_backward_step(
|
||||
engine, return_loss, return_output_label
|
||||
)
|
||||
|
||||
# Compatible for old code
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
return output, label, accum_loss
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
||||
|
||||
class InterleavedPipelineScheduler(PipelineScheduler):
|
||||
|
@@ -799,7 +816,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
|
||||
|
||||
self._call_hooks("before_forward", data)
|
||||
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
output_obj = self._call_engine(engine.model[chunk_id], data)
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
|
||||
# Convert output_obj to fp32 when last model chunk of last stage
|
||||
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
|
||||
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
|
||||
|
@@ -819,7 +839,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
self._accum_loss.add_(loss_reduced.detach())
|
||||
output_obj = loss_reduced
|
||||
|
||||
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
moe_loss = (
|
||||
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE"
|
||||
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
|
||||
)
|
||||
moe_loss /= self.num_microbatches
|
||||
|
||||
if self._accum_moe_loss is not None:
|
||||
|
@@ -1354,4 +1378,8 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
|
||||
self._clear_state()
|
||||
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
# Compatible for old code
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
return output, label, accum_loss
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
return output, label, accum_loss, accum_moe_loss
|
||||
|
|
|
@@ -205,5 +205,4 @@ class Trainer:
|
|||
Returns:
|
||||
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss).
|
||||
"""
|
||||
output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
|
||||
return output, label, loss, moe_loss
|
||||
return self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
|
||||
|
|
|
@@ -253,8 +253,14 @@ def args_sanity_check():
|
|||
# process the model config
|
||||
if "use_flash_attn" not in gpc.config.model:
|
||||
gpc.config.model._add_item("use_flash_attn", True)
|
||||
if "num_experts" not in model:
|
||||
model._add_item("num_experts", 1)
|
||||
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
if "num_experts" not in model:
|
||||
model._add_item("num_experts", 1)
|
||||
if "moe_use_residual" not in model:
|
||||
model._add_item("moe_use_residual", False)
|
||||
if "moe_gate_k" not in model:
|
||||
model._add_item("moe_gate_k", 2)
|
||||
|
||||
# process the parallel config
|
||||
if "sequence_parallel" not in gpc.config.parallel:
|
||||
|
|
|
@@ -5,6 +5,7 @@ from .embedding import Embedding1D, RotaryEmbedding
|
|||
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
|
||||
from .metrics import AccPerplex
|
||||
from .modeling_internlm import build_model_with_cfg
|
||||
from .modeling_moe import build_model_with_moe_cfg
|
||||
from .multi_head_attention import MHA
|
||||
from .utils import gather_forward_split_backward
|
||||
|
||||
|
@@ -18,4 +19,5 @@ __all__ = [
|
|||
"MHA",
|
||||
"gather_forward_split_backward",
|
||||
"build_model_with_cfg",
|
||||
"build_model_with_moe_cfg",
|
||||
]
|
||||
|
|
|
@@ -18,7 +18,6 @@ from internlm.model.linear import (
|
|||
RewardModelLinear,
|
||||
ScaleColumnParallelLinear,
|
||||
)
|
||||
from internlm.model.moe import MoE
|
||||
from internlm.model.multi_head_attention import MHA
|
||||
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
|
@@ -51,17 +50,6 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether use flash-attn. True by default.
|
||||
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
|
||||
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
|
||||
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
|
||||
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
|
||||
infinite capacity).
|
||||
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
|
||||
(https://arxiv.org/abs/2201.05596) layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -84,15 +72,6 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
num_experts: int = 1,
|
||||
moe_gate_k: int = 1,
|
||||
moe_capacity_factor: float = 1.0,
|
||||
moe_eval_capacity_factor: float = 1.0,
|
||||
moe_min_capacity: int = 4,
|
||||
moe_noisy_gate_policy: str = None,
|
||||
moe_drop_tokens: bool = True,
|
||||
moe_use_rts: bool = True,
|
||||
moe_use_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
|
@@ -127,77 +106,41 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
|
||||
for param in self.norm1.parameters():
|
||||
param.is_norm = True
|
||||
for param in self.norm2.parameters():
|
||||
param.is_norm = True
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.moe_gate_k = moe_gate_k
|
||||
self.moe_capacity_factor = moe_capacity_factor
|
||||
self.moe_eval_capacity_factor = moe_eval_capacity_factor
|
||||
self.moe_min_capacity = moe_min_capacity
|
||||
self.moe_noisy_gate_policy = moe_noisy_gate_policy
|
||||
self.moe_drop_tokens = moe_drop_tokens
|
||||
self.moe_use_rts = moe_use_rts
|
||||
self.moe_use_residual = moe_use_residual
|
||||
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
|
||||
if num_experts <= 1: # dense, not MoE
|
||||
if use_swiglu:
|
||||
self.mlp = FeedForward(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
self.mlp = ParallelFusedMLP(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
activation="gelu_approx",
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias1=False,
|
||||
bias2=False,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
checkpoint_lvl=0,
|
||||
heuristic="auto",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.mlp.named_parameters():
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
else:
|
||||
# replace mlp by MoE module. The expert in MoE is a FeedForward module.
|
||||
self.mlp = MoE(
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
k=moe_gate_k,
|
||||
capacity_factor=moe_capacity_factor,
|
||||
eval_capacity_factor=moe_eval_capacity_factor,
|
||||
min_capacity=moe_min_capacity,
|
||||
noisy_gate_policy=moe_noisy_gate_policy,
|
||||
drop_tokens=moe_drop_tokens,
|
||||
use_rts=moe_use_rts,
|
||||
use_residual=moe_use_residual,
|
||||
if use_swiglu:
|
||||
self.mlp = FeedForward(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.mlp.moe_layer.experts.named_parameters():
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
else:
|
||||
self.mlp = ParallelFusedMLP(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
activation="gelu_approx",
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias1=False,
|
||||
bias2=False,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
checkpoint_lvl=0,
|
||||
heuristic="auto",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.mlp.named_parameters():
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
|
||||
self.dropout2 = nn.Dropout(drop_rate)
|
||||
self.use_swiglu = use_swiglu
|
||||
self.use_scaled_init = use_scaled_init
|
||||
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
|
||||
self.return_residual = False
|
||||
self.reset_parameters() # TODO: check this should be changed when moe is added
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
with torch.no_grad():
|
||||
|
@@ -229,7 +172,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
if self.checkpoint and self.training:
|
||||
return activation_checkpoint(
|
||||
self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
|
||||
) # TODO: check whether this will be affected by moe
|
||||
)
|
||||
else:
|
||||
return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
|
||||
|
||||
|
@@ -279,14 +222,9 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
|
||||
# MLP.
|
||||
moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
if self.num_experts <= 1: # dense mlp output
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
else: # MoE output
|
||||
hidden_states, moe_loss, _ = self.mlp(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states + residual, moe_loss
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class PackedFlashInternLm1D(nn.Module):
|
||||
|
@@ -316,17 +254,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
|
||||
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
|
||||
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
|
||||
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
|
||||
to infinite capacity).
|
||||
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
|
||||
(https://arxiv.org/abs/2201.05596) layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -357,15 +285,6 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
num_experts: bool = 1,
|
||||
moe_gate_k: int = 1,
|
||||
moe_capacity_factor: float = 1.0,
|
||||
moe_eval_capacity_factor: float = 1.0,
|
||||
moe_min_capacity: int = 4,
|
||||
moe_noisy_gate_policy: str = None,
|
||||
moe_drop_tokens: bool = True,
|
||||
moe_use_rts: bool = True,
|
||||
moe_use_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@@ -415,15 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
num_experts=num_experts,
|
||||
moe_gate_k=moe_gate_k,
|
||||
moe_capacity_factor=moe_capacity_factor,
|
||||
moe_eval_capacity_factor=moe_eval_capacity_factor,
|
||||
moe_min_capacity=moe_min_capacity,
|
||||
moe_noisy_gate_policy=moe_noisy_gate_policy,
|
||||
moe_drop_tokens=moe_drop_tokens,
|
||||
moe_use_rts=moe_use_rts,
|
||||
moe_use_residual=moe_use_residual,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
|
@@ -450,8 +360,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
|
||||
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
|
||||
# attention_mask: compute attention on the places where the value is 1
|
||||
# old condition may fail when use shared embedding
|
||||
if gpc.is_pipeline_first_stage():
|
||||
if hasattr(self, "embedding"):
|
||||
hidden_states = self.embedding(input_ids)
|
||||
if self.embed_grad_scale != 1:
|
||||
hidden_states = (
|
||||
|
@@ -472,16 +381,14 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
indexes = indexes[0]
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
|
||||
|
||||
moe_losses = []
|
||||
for _, block in enumerate(self.blocks):
|
||||
hidden_states, mos_loss = block(
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
indexes=indexes,
|
||||
inference_params=inference_params,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
moe_losses.append(mos_loss)
|
||||
|
||||
if hasattr(self, "norm"):
|
||||
hidden_states = self.norm(hidden_states.float())
|
||||
|
@@ -490,7 +397,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
|
||||
if not self.parallel_output:
|
||||
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
|
||||
return hidden_states, moe_losses
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
|
||||
|
@@ -558,15 +465,6 @@ def build_model_with_cfg(
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
num_experts: int = 1,
|
||||
moe_gate_k: int = 1,
|
||||
moe_capacity_factor: float = 1.0,
|
||||
moe_eval_capacity_factor: float = 1.0,
|
||||
moe_min_capacity: int = 4,
|
||||
moe_noisy_gate_policy: str = None,
|
||||
moe_drop_tokens: bool = True,
|
||||
moe_use_rts: bool = True,
|
||||
moe_use_residual: bool = False,
|
||||
):
|
||||
"""
|
||||
Build model with config.
|
||||
|
@@ -597,17 +495,7 @@ def build_model_with_cfg(
|
|||
use_scaled_init (bool): Whether to use scaled init. True by default.
|
||||
use_swiglu (bool): Whether to use swiglu. True by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
|
||||
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
|
||||
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
|
||||
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
|
||||
to infinite capacity).
|
||||
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
|
||||
(https://arxiv.org/abs/2201.05596) layer.
|
||||
|
||||
"""
|
||||
|
||||
cfg = dict(
|
||||
|
@@ -632,15 +520,6 @@ def build_model_with_cfg(
|
|||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
num_experts=num_experts,
|
||||
moe_gate_k=moe_gate_k,
|
||||
moe_capacity_factor=moe_capacity_factor,
|
||||
moe_eval_capacity_factor=moe_eval_capacity_factor,
|
||||
moe_min_capacity=moe_min_capacity,
|
||||
moe_noisy_gate_policy=moe_noisy_gate_policy,
|
||||
moe_drop_tokens=moe_drop_tokens,
|
||||
moe_use_rts=moe_use_rts,
|
||||
moe_use_residual=moe_use_residual,
|
||||
)
|
||||
|
||||
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
|
||||
|
|
|
@@ -0,0 +1,646 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from flash_attn.modules.embedding import ParallelGPT2Embeddings
|
||||
from flash_attn.modules.mlp import ParallelFusedMLP
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
||||
from internlm.model.embedding import Embedding1D
|
||||
from internlm.model.linear import (
|
||||
FeedForward,
|
||||
RewardModelLinear,
|
||||
ScaleColumnParallelLinear,
|
||||
)
|
||||
from internlm.model.moe import MoE
|
||||
from internlm.model.multi_head_attention import MHA
|
||||
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.utils.checkpoint import activation_checkpoint
|
||||
from internlm.utils.common import filter_kwargs
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
MODEL_TYPE = "INTERNLM_MoE"
|
||||
|
||||
logger = get_logger(__file__)
|
||||
RMSNorm = try_import_RMSNorm()
|
||||
|
||||
|
||||
class PackedFlashBaseLayer1D(nn.Module):
|
||||
"""
|
||||
1D Packed Flash Base Layer.
|
||||
|
||||
Args:
|
||||
hidden_size (int): The hidden size of model. 768 by default.
|
||||
num_attention_heads (int): The number of attention heads. 12 by default.
|
||||
mlp_ratio (int): The ratio of MLP layers. 4 by default.
|
||||
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
|
||||
drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
|
||||
dtype (torch.dtype): Type of data. torch.float by default.
|
||||
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
|
||||
checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
|
||||
layer_idx (int): The index of current layer. 0 by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether use flash-attn. True by default.
|
||||
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
|
||||
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
|
||||
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
|
||||
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
|
||||
infinite capacity).
|
||||
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
|
||||
(https://arxiv.org/abs/2201.05596) layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
mlp_ratio: int = 4,
|
||||
attn_drop_rate: float = 0,
|
||||
drop_rate: float = 0.0,
|
||||
max_position_embeddings: int = 2048,
|
||||
dtype: torch.dtype = torch.float,
|
||||
layer_norm_epsilon: float = 1e-6,
|
||||
checkpoint: bool = False,
|
||||
layer_idx: int = 0,
|
||||
use_dynamic_ntk_rope: bool = False,
|
||||
residual_in_fp32: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
norm_type: str = "rmsnorm",
|
||||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
num_experts: int = 1,
|
||||
moe_gate_k: int = 1,
|
||||
moe_capacity_factor: float = 1.0,
|
||||
moe_eval_capacity_factor: float = 1.0,
|
||||
moe_min_capacity: int = 4,
|
||||
moe_noisy_gate_policy: str = None,
|
||||
moe_drop_tokens: bool = True,
|
||||
moe_use_rts: bool = True,
|
||||
moe_use_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
|
||||
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
|
||||
self.layer_idx = layer_idx
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
self.mixer = MHA(
|
||||
embed_dim=hidden_size,
|
||||
num_heads=num_attention_heads,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
dropout=attn_drop_rate,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
softmax_scale=1 / math.sqrt(head_dim),
|
||||
causal=True,
|
||||
layer_idx=layer_idx,
|
||||
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
|
||||
rotary_emb_dim=head_dim,
|
||||
rotary_emb_scale_base=0,
|
||||
use_flash_attn=use_flash_attn,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
self.dropout1 = nn.Dropout(drop_rate)
|
||||
if norm_type == "rmsnorm":
|
||||
self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
|
||||
for param in self.norm1.parameters():
|
||||
param.is_norm = True
|
||||
for param in self.norm2.parameters():
|
||||
param.is_norm = True
|
||||
|
||||
self.num_experts = num_experts
|
||||
self.moe_gate_k = moe_gate_k
|
||||
self.moe_capacity_factor = moe_capacity_factor
|
||||
self.moe_eval_capacity_factor = moe_eval_capacity_factor
|
||||
self.moe_min_capacity = moe_min_capacity
|
||||
self.moe_noisy_gate_policy = moe_noisy_gate_policy
|
||||
self.moe_drop_tokens = moe_drop_tokens
|
||||
self.moe_use_rts = moe_use_rts
|
||||
self.moe_use_residual = moe_use_residual
|
||||
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
|
||||
if num_experts <= 1: # dense, not MoE
|
||||
if use_swiglu:
|
||||
self.mlp = FeedForward(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
else:
|
||||
self.mlp = ParallelFusedMLP(
|
||||
hidden_size,
|
||||
int(hidden_size * mlp_ratio),
|
||||
out_features=hidden_size,
|
||||
activation="gelu_approx",
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias1=False,
|
||||
bias2=False,
|
||||
sequence_parallel=gpc.config.model.sequence_parallel,
|
||||
checkpoint_lvl=0,
|
||||
heuristic="auto",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.mlp.named_parameters():
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
else:
|
||||
# replace mlp by MoE module. The expert in MoE is a FeedForward module.
|
||||
self.mlp = MoE(
|
||||
hidden_size=hidden_size,
|
||||
num_experts=num_experts,
|
||||
ep_size=ep_size,
|
||||
k=moe_gate_k,
|
||||
capacity_factor=moe_capacity_factor,
|
||||
eval_capacity_factor=moe_eval_capacity_factor,
|
||||
min_capacity=moe_min_capacity,
|
||||
noisy_gate_policy=moe_noisy_gate_policy,
|
||||
drop_tokens=moe_drop_tokens,
|
||||
use_rts=moe_use_rts,
|
||||
use_residual=moe_use_residual,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.mlp.moe_layer.experts.named_parameters():
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
|
||||
self.dropout2 = nn.Dropout(drop_rate)
|
||||
self.use_swiglu = use_swiglu
|
||||
self.use_scaled_init = use_scaled_init
|
||||
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
|
||||
self.return_residual = False
|
||||
self.reset_parameters() # TODO: check this should be changed when moe is added
|
||||
|
||||
def reset_parameters(self):
|
||||
with torch.no_grad():
|
||||
for name, param in self.mixer.named_parameters():
|
||||
if param.ndim == 1:
|
||||
param.data.zero_()
|
||||
elif "Wqkv" in name:
|
||||
normal_(std=0.006)(param.data)
|
||||
elif self.use_scaled_init:
|
||||
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
|
||||
else:
|
||||
normal_(std=0.0015)(param.data)
|
||||
|
||||
for name, param in self.mlp.named_parameters():
|
||||
if param.ndim == 1 and "bias" in name:
|
||||
param.data.zero_()
|
||||
elif self.use_swiglu:
|
||||
if self.use_scaled_init and "w2" in name:
|
||||
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
|
||||
else:
|
||||
normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
|
||||
else:
|
||||
if self.use_scaled_init and "fc1" not in name:
|
||||
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
|
||||
else:
|
||||
normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
|
||||
|
||||
def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
|
||||
if self.checkpoint and self.training:
|
||||
return activation_checkpoint(
|
||||
self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
|
||||
) # TODO: check whether this will be affected by moe
|
||||
else:
|
||||
return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
|
||||
|
||||
def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
hidden_states: the sequence to the encoder layer (required).
|
||||
residual: hidden_states = Attn/MLP(LN(residual))
|
||||
cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
|
||||
indexes: the length of index is same as hidden states, which stand for the current position
|
||||
"""
|
||||
mixer_kwargs = {
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"max_seqlen": max_seqlen,
|
||||
"indexes": indexes,
|
||||
"inference_params": inference_params,
|
||||
}
|
||||
|
||||
def _dropout_and_norm_attn(_hidden_states):
|
||||
_dropped = self.dropout1(_hidden_states)
|
||||
_residual = _dropped
|
||||
_hidden_states = self.norm1(_residual.float())
|
||||
return _residual, _hidden_states
|
||||
|
||||
if self.dropout_selective_checkpoint:
|
||||
residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
|
||||
else:
|
||||
residual, hidden_states = _dropout_and_norm_attn(hidden_states)
|
||||
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
|
||||
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
||||
|
||||
def _dropout_and_norm_ffn(_residual, _hidden_states):
|
||||
_dropped = self.dropout2(_hidden_states)
|
||||
_residual = (_dropped + _residual) if _residual is not None else _dropped
|
||||
_hidden_states = self.norm2(_residual.float())
|
||||
return _residual, _hidden_states
|
||||
|
||||
if self.dropout_selective_checkpoint:
|
||||
residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
|
||||
else:
|
||||
residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
|
||||
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
|
||||
# MLP.
|
||||
moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
if self.num_experts <= 1: # dense mlp output
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
else: # MoE output
|
||||
hidden_states, moe_loss, _ = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states + residual, moe_loss
|
||||
|
||||
|
||||
class PackedFlashInternLm1D(nn.Module):
|
||||
"""
|
||||
1D Packed Flash InternLm.
|
||||
|
||||
Args:
|
||||
num_layers (int): The number of layer. 12 by default.
|
||||
hidden_size (int): The size of hidden state. 768 by default.
|
||||
num_attention_heads (int): The number of attention head. 12 by default.
|
||||
vocab_size (int): The size of vocabulary. 50304 by default.
|
||||
mlp_ratio (int): The ratio of MLP layers. 4 by default.
|
||||
attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
|
||||
drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
|
||||
dtype (torch.dtype): The type of data. torch.float by default.
|
||||
checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
|
||||
of layers. 0.0 by default.
|
||||
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
|
||||
first (bool): Whether input embedding layer or not. False by default.
|
||||
last (bool): Whether output embedding layer or not. False by default.
|
||||
embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
|
||||
True by default.
|
||||
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
|
||||
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
|
||||
start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
|
||||
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
|
||||
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
|
||||
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
|
||||
to infinite capacity).
|
||||
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
|
||||
(https://arxiv.org/abs/2201.05596) layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
mlp_ratio: int = 4.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_rate: float = 0.0,
|
||||
max_position_embeddings: int = 2048,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: float = 0.0,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden: bool = False,
|
||||
embed_grad_scale: float = 0.1,
|
||||
parallel_output: bool = True,
|
||||
start_layer_idx: int = 0,
|
||||
use_dynamic_ntk_rope: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
residual_in_fp32: bool = False,
|
||||
norm_type: str = "rmsnorm",
|
||||
is_reward: bool = False,
|
||||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
num_experts: bool = 1,
|
||||
moe_gate_k: int = 1,
|
||||
moe_capacity_factor: float = 1.0,
|
||||
moe_eval_capacity_factor: float = 1.0,
|
||||
moe_min_capacity: int = 4,
|
||||
moe_noisy_gate_policy: str = None,
|
||||
moe_drop_tokens: bool = True,
|
||||
moe_use_rts: bool = True,
|
||||
moe_use_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
checkpoint_layer_num = int(num_layers * checkpoint)
|
||||
|
||||
if is_reward:
|
||||
head_cls = RewardModelLinear
|
||||
else:
|
||||
head_cls = ScaleColumnParallelLinear
|
||||
if first:
|
||||
if embed_split_hidden:
|
||||
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
|
||||
else:
|
||||
self.embedding = ParallelGPT2Embeddings(
|
||||
embed_dim=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
max_position_embeddings=-1,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
padding_idx=None,
|
||||
sequence_parallel=gpc.config.parallel.sequence_parallel,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.embedding.named_parameters():
|
||||
normal_(std=0.0052)(param)
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
self.embed_grad_scale = embed_grad_scale
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PackedFlashBaseLayer1D(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_rate=drop_rate,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
dtype=dtype,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
checkpoint=lid < checkpoint_layer_num,
|
||||
layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
|
||||
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
device=device,
|
||||
norm_type=norm_type,
|
||||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
num_experts=num_experts,
|
||||
moe_gate_k=moe_gate_k,
|
||||
moe_capacity_factor=moe_capacity_factor,
|
||||
moe_eval_capacity_factor=moe_eval_capacity_factor,
|
||||
moe_min_capacity=moe_min_capacity,
|
||||
moe_noisy_gate_policy=moe_noisy_gate_policy,
|
||||
moe_drop_tokens=moe_drop_tokens,
|
||||
moe_use_rts=moe_use_rts,
|
||||
moe_use_residual=moe_use_residual,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
)
|
||||
if last:
|
||||
if norm_type == "rmsnorm":
|
||||
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.head = head_cls(
|
||||
in_features=hidden_size,
|
||||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||
process_group=gpc.get_group(ParallelMode.TENSOR),
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
weight_scale=embed_grad_scale,
|
||||
)
|
||||
for _, param in self.head.named_parameters():
|
||||
normal_(std=0.0052)(param)
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
self.parallel_output = parallel_output
|
||||
|
||||
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
|
||||
# attention_mask: compute attention on the places where the value is 1
|
||||
# old condition may fail when using a shared embedding
|
||||
if gpc.is_pipeline_first_stage():
|
||||
hidden_states = self.embedding(input_ids)
|
||||
if self.embed_grad_scale != 1:
|
||||
hidden_states = (
|
||||
self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
|
||||
)
|
||||
if isinstance(cu_seqlens, list):
|
||||
assert len(cu_seqlens) == 1
|
||||
cu_seqlens = cu_seqlens[0].to(hidden_states.device)
|
||||
|
||||
if cu_seqlens is not None:
|
||||
cu_seqlens = cu_seqlens.squeeze(0)
|
||||
hidden_states = hidden_states.squeeze(0)  # If cu_seqlens is passed in, it indicates a packed state,
|
||||
# the batch dimension with a size of 1 should be directly squeezed off.
|
||||
|
||||
if indexes is not None:
|
||||
assert len(indexes) == 1
|
||||
# The indexes are used to indicate the actual position IDs of each token in the packed input.
|
||||
indexes = indexes[0]
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
|
||||
|
||||
moe_losses = []
|
||||
for _, block in enumerate(self.blocks):
|
||||
hidden_states, moe_loss = block(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
indexes=indexes,
|
||||
inference_params=inference_params,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
moe_losses.append(moe_loss)
|
||||
|
||||
if hasattr(self, "norm"):
|
||||
hidden_states = self.norm(hidden_states.float())
|
||||
if hasattr(self, "head"):
|
||||
hidden_states = self.head(hidden_states)
|
||||
|
||||
if not self.parallel_output:
|
||||
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
|
||||
return hidden_states, moe_losses
|
||||
|
||||
|
||||
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
|
||||
"""
|
||||
build generic model 1d
|
||||
|
||||
Args:
|
||||
num_layers (int): The number of layer.
|
||||
num_chunks (int): The number of partitions in pipeline parallel.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
|
||||
|
||||
"""
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"The layer sharding is {all_parts}.")
|
||||
|
||||
models = []
|
||||
|
||||
for start, end in parts:
|
||||
kwargs["num_layers"] = end - start
|
||||
kwargs["first"] = start == 0
|
||||
# If there is no content in the final layer, assign the last layer.
|
||||
kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
|
||||
kwargs["device"] = device
|
||||
kwargs["start_layer_idx"] = start
|
||||
chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)
|
||||
|
||||
models.append(chunk)
|
||||
torch.distributed.barrier()
|
||||
if len(models) == 1:
|
||||
model = models[0]
|
||||
else:
|
||||
model = nn.ModuleList(models)
|
||||
|
||||
return model
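For intuition about the layer-sharding log above: partition_uniform splits num_layers across the pipeline ranks, and the `for start, end in parts` loop suggests each rank receives a list of (start, end) chunks. Assuming that shape, a 32-layer model on 4 pipeline ranks with num_chunks=1 would shard roughly as follows.

```python
# Assumed output shape of partition_uniform(32, pipeline_size=4, num_chunks=1),
# inferred from the `for start, end in parts` usage above.
all_parts = [[(0, 8)], [(8, 16)], [(16, 24)], [(24, 32)]]
parts = all_parts[2]    # pipeline rank 2 would build layers 16..23
```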
|
||||
|
||||
|
||||
@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
|
||||
def build_model_with_moe_cfg(
|
||||
num_chunks=1,
|
||||
checkpoint=0.0,
|
||||
dtype=torch.float,
|
||||
embed_split_hidden=False,
|
||||
num_layers=48,
|
||||
hidden_size=2048,
|
||||
vocab_size=50304,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
num_attention_heads=32,
|
||||
max_position_embeddings=2048,
|
||||
mlp_ratio=4.0,
|
||||
residual_in_fp32=False,
|
||||
use_dynamic_ntk_rope=False,
|
||||
norm_type="rmsnorm",
|
||||
drop_rate=0,
|
||||
attn_drop_rate=0,
|
||||
apply_post_layer_norm=False, # pylint: disable=W0613
|
||||
layer_norm_epsilon=1e-5,
|
||||
is_reward=False,
|
||||
dropout_selective_checkpoint=True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
num_experts: int = 1,
|
||||
moe_gate_k: int = 1,
|
||||
moe_capacity_factor: float = 1.0,
|
||||
moe_eval_capacity_factor: float = 1.0,
|
||||
moe_min_capacity: int = 4,
|
||||
moe_noisy_gate_policy: str = None,
|
||||
moe_drop_tokens: bool = True,
|
||||
moe_use_rts: bool = True,
|
||||
moe_use_residual: bool = False,
|
||||
):
|
||||
"""
|
||||
Build model with config.
|
||||
|
||||
Args:
|
||||
num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
|
||||
checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
|
||||
dtype (torch.dtype): The type of data. torch.float by default.
|
||||
embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
|
||||
False by default.
|
||||
num_layers (int): The number of layer. 48 by default.
|
||||
hidden_size (int): The size of hidden state. 2048 by default.
|
||||
vocab_size (int): The size of vocabulary. 50304 by default.
|
||||
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
|
||||
parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
|
||||
num_attention_heads (int): The number of attention head. 32 by default.
|
||||
mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
|
||||
because this parameter requires inconsistent data types to be passed between pipelines,
|
||||
which requires significant modifications to internlm.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
drop_rate (float): The dropout rate of input hidden state. 0 by default.
|
||||
attn_drop_rate (float): The dropout rate of attention module. 0 by default.
|
||||
apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
|
||||
layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
|
||||
is_reward (bool): Whether to use reward model. False by default.
|
||||
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
|
||||
use_scaled_init (bool): Whether to use scaled init. True by default.
|
||||
use_swiglu (bool): Whether to use swiglu. True by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
|
||||
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
|
||||
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
|
||||
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
|
||||
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
|
||||
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
|
||||
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
|
||||
to infinite capacity).
|
||||
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
|
||||
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
|
||||
(https://arxiv.org/abs/2201.05596) layer.
|
||||
"""
|
||||
|
||||
cfg = dict(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
vocab_size=vocab_size,
|
||||
embed_grad_scale=embed_grad_scale,
|
||||
parallel_output=parallel_output,
|
||||
mlp_ratio=mlp_ratio,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
|
||||
norm_type=norm_type,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
is_reward=is_reward,
|
||||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
num_experts=num_experts,
|
||||
moe_gate_k=moe_gate_k,
|
||||
moe_capacity_factor=moe_capacity_factor,
|
||||
moe_eval_capacity_factor=moe_eval_capacity_factor,
|
||||
moe_min_capacity=moe_min_capacity,
|
||||
moe_noisy_gate_policy=moe_noisy_gate_policy,
|
||||
moe_drop_tokens=moe_drop_tokens,
|
||||
moe_use_rts=moe_use_rts,
|
||||
moe_use_residual=moe_use_residual,
|
||||
)
|
||||
|
||||
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
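The MODEL_INITIALIZER.register_module decorator above is what connects this file to the config: when model_type is "INTERNLM_MoE", the registry hands back build_model_with_moe_cfg. A rough sketch of that dispatch is below; the registry lookup method name is an assumption and has not been checked against the InternLM source.

```python
# Rough sketch of the registry dispatch; get_module is an assumed lookup API.
def build_model_from_gpc_config(gpc):
    model_type = gpc.config.get("model_type", "INTERNLM")   # "INTERNLM_MoE" in this config
    builder = MODEL_INITIALIZER.get_module(module_name=model_type)
    return builder(**gpc.config.model)
```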
|
|
@@ -111,7 +111,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
|
||||
adam_cfg = gpc.config.adam
|
||||
# split the moe parameters into different groups
|
||||
if gpc.config.model.num_experts > 1:
|
||||
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
|
||||
params = create_param_groups(model, adam_cfg.weight_decay)
|
||||
else:
|
||||
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
|
||||
|
@@ -435,8 +435,7 @@ def record_current_batch_training_metrics(
|
|||
infos = {
|
||||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
"loss": loss.item() - moe_loss.item(),
|
||||
"moe_loss": moe_loss.item(),
|
||||
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
|
||||
"tgs (tokens/gpu/second)": tgs_origin,
|
||||
"tgs/last_tgs_1": last_tgs_1,
|
||||
"tgs/tgs_all": tgs_all,
|
||||
|
@@ -448,6 +447,8 @@ def record_current_batch_training_metrics(
|
|||
"loss_scale": scaler,
|
||||
"grad_norm": grad_norm,
|
||||
}
|
||||
if moe_loss is not None:
|
||||
infos["moe_loss"] = moe_loss.item()
|
||||
|
||||
infos["micro_num"] = len(batch[1])
|
||||
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
|
||||
|
@@ -481,13 +482,14 @@ def record_current_batch_training_metrics(
|
|||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"loss": loss.item() - moe_loss.item(),
|
||||
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": last_tgs_1,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
}
|
||||
panel_metrics["moe_loss"] = moe_loss.item()
|
||||
for norm_key, norm_value in grad_norm.items():
|
||||
panel_metrics[norm_key] = norm_value
|
||||
|
||||
|
|
|
@@ -97,6 +97,7 @@ def evaluate_on_val_dls(
|
|||
disable=not verbose,
|
||||
leave=False,
|
||||
):
|
||||
moe_loss = None
|
||||
with torch.inference_mode():
|
||||
if gpc.is_using_pp():
|
||||
total_val_bsz = len(batch[1])
|
||||
|
@@ -112,9 +113,15 @@ def evaluate_on_val_dls(
|
|||
tensor_shape=tensor_shape,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
# Compatible for old code
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
elif gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
else:
|
||||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
|
@@ -126,11 +133,16 @@ def evaluate_on_val_dls(
|
|||
grad_accum_batch_size=grad_accum_batch_size,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
elif gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch, forward_only=True, return_loss=True, return_output_label=False
|
||||
)
|
||||
if verbose:
|
||||
val_loss += loss.item() - moe_loss.item()
|
||||
val_loss += loss.item() - moe_loss.item() if moe_loss is not None else loss.item()
|
||||
|
||||
assert val_idx != -1
|
||||
dist.barrier()
|
||||
|
|
|
@@ -186,11 +186,20 @@ def train(
|
|||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||
# Compatible for old code
|
||||
moe_loss = None
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=False, return_loss=True, return_output_label=False
|
||||
)
|
||||
elif gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch, forward_only=False, return_loss=True, return_output_label=False
|
||||
)
|
||||
if gpc.is_rank_for_log():
|
||||
assert loss is not None and not math.isnan(loss.item())
|
||||
global cur_loss_list
|
||||
cur_loss_list.append(loss.item() - moe_loss.item())
|
||||
cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item()))
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
|
|
|
@@ -59,6 +59,7 @@ init_config = Config(
|
|||
def init_naive_model():
|
||||
# let MODEL_INITIALIZER to work
|
||||
import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import
|
||||
import internlm.model.modeling_moe # noqa # pylint: disable=unused-import
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
|
|
train.py
|
@@ -219,12 +219,21 @@ def main(args):
|
|||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
moe_loss = None
|
||||
if gpc.config.get("model_type") == "INTERNLM":
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
if gpc.config.get("model_type") == "INTERNLM_MoE":
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
|
|