add no-interleaved & no-overlapped moe pp support

pull/182/head
zhanglei 2023-08-14 11:10:37 +08:00
parent d8e5397159
commit 1accc9f08d
2 changed files with 179 additions and 15 deletions

configs/moe_cfg.py (new file, 152 lines)

@@ -0,0 +1,152 @@
JOB_NAME = "7b_train"
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 16
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training (load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num is the number of micro-batches contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, which disables evaluation
valid_every=50,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
rampup_batch_size="",
# Datasets with fewer than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
)
grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of steps to increase loss scale when no overflow occurs
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)
hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
zero_overlap_communication=False,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)
loss = dict(
label_smoothing=0,
moe_loss_coeff=0.01,
)
adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)
lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)
beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)
model = dict(
checkpoint=False,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
sequence_parallel=False,
num_experts=4,
moe_use_residual=True,
)
"""
zero1 parallel:
1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. if zero1 > 1 and zero1 <= dp world size, the zero process group is a subset of the dp process group.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
# zero1=8,
pipeline=dict(size=4, interleaved_overlap=False),
tensor=dict(size=2),
)
cudnn_deterministic = False
cudnn_benchmark = False
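
For orientation, here is a minimal sizing sketch for the parallel section above. It assumes the usual decomposition world_size = pipeline_size * tensor_size * data_parallel_size described in the docstring; the 16-GPU world size is an assumption chosen for illustration, not something fixed by this config.

# Hypothetical sizing check for parallel=dict(pipeline=dict(size=4), tensor=dict(size=2)).
pipeline_size = 4                                       # parallel.pipeline.size
tensor_size = 2                                         # parallel.tensor.size
world_size = 16                                         # assumed GPU count for this example

model_parallel_size = pipeline_size * tensor_size       # 8 GPUs per model replica
assert world_size % model_parallel_size == 0
data_parallel_size = world_size // model_parallel_size  # 2 data-parallel replicas
# zero1 <= 0 shards optimizer states across all data_parallel_size ranks;
# an explicit zero1 value must stay within 1..data_parallel_size.
print(model_parallel_size, data_parallel_size)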


@@ -239,7 +239,7 @@ class PipelineScheduler(BaseScheduler):
"""
return step_id
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff: float = 1.0):
"""
Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
@@ -259,7 +259,7 @@ class PipelineScheduler(BaseScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model, data)
output_obj, moe_losses = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)
if gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -272,12 +272,14 @@ class PipelineScheduler(BaseScheduler):
self._call_hooks("after_criterion", loss)
loss_reduced = loss / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
accum_loss.add_(loss_reduced)
output_obj = loss_reduced
return output_obj
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
return output_obj, moe_loss
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
"""
Backward step through the passed-in output tensor. If it is the last stage, the
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
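
As an aside, the per-microbatch MoE loss handling added to _forward_step above reduces to a small, self-contained helper. The sketch below only mirrors the two lines of the patch (sum, scale by the coefficient, normalize by the microbatch count) and is not part of the commit.

def aggregate_moe_loss(moe_losses, moe_loss_coeff, num_microbatches):
    # Sum the per-layer gating losses returned by the model, weight them by the
    # configured coefficient (loss.moe_loss_coeff in the config above), and divide
    # by the microbatch count so the value accumulates on the same scale as the
    # per-step language-model loss.
    return sum(moe_losses) * moe_loss_coeff / num_microbatches

# e.g. four MoE layers with coeff 0.01 and micro_num=4 microbatches:
# aggregate_moe_loss([l0, l1, l2, l3], 0.01, 4)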
@@ -311,6 +313,9 @@ class PipelineScheduler(BaseScheduler):
self._call_hooks("before_backward", output_obj, output_obj_grad)
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
if moe_loss is not None:
moe_loss.backward(retain_graph=True)
if output_obj_grad is None:
engine.backward(output_obj)
else:
@@ -329,7 +334,7 @@ class PipelineScheduler(BaseScheduler):
return input_obj_grad
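
The retain_graph=True in the new _backward_step branch above is needed because the auxiliary MoE loss and the stage output live on the same autograd graph, and the stage output is backpropagated immediately afterwards. A toy PyTorch illustration of that constraint (standalone, not using the InternLM modules):

import torch

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4)
hidden = torch.tanh(x @ w)               # intermediate shared by both backward passes
output_obj = hidden @ w.t()              # stands in for the stage output tensor
moe_loss = 0.01 * (hidden ** 2).mean()   # stands in for the weighted auxiliary loss

output_obj_grad = torch.ones_like(output_obj)  # grad received from the next stage
moe_loss.backward(retain_graph=True)     # keep the shared graph alive for the next call
output_obj.backward(output_obj_grad)     # would raise without retain_graph above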
def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff: float = 1.0):
"""
This function performs forward only computation process. The scheduling of microbatches is similar to the
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
@@ -376,12 +381,13 @@ class PipelineScheduler(BaseScheduler):
input_obj = None
# Perform forward computation
output_obj = self._forward_step(
output_obj, _ = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
moe_loss_coeff=moe_loss_coeff,
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -395,7 +401,7 @@ class PipelineScheduler(BaseScheduler):
return output, label, accum_loss
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff: float = 1.0):
"""
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
It consists of three stages: warmup, 1F1B, and cooldown.
@@ -441,6 +447,7 @@ class PipelineScheduler(BaseScheduler):
# Input, output tensors only need to be saved when doing backward passes
input_objs = []
output_objs = []
moe_losses = []
return_tensors = []
accum_loss = (
torch.zeros(1, device=get_current_device())
@@ -468,12 +475,13 @@ class PipelineScheduler(BaseScheduler):
input_obj = None
# Perform forward computation
output_obj = self._forward_step(
output_obj, moe_loss = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
moe_loss_coeff=moe_loss_coeff,
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -493,6 +501,7 @@ class PipelineScheduler(BaseScheduler):
input_objs.append(input_obj)
output_objs.append(output_obj)
moe_losses.append(moe_loss)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@@ -512,7 +521,7 @@ class PipelineScheduler(BaseScheduler):
# Run 1F1B in steady state.
for i in range(num_1f1b_micropairs):
# Perform forward computation
output_obj = self._forward_step(
output_obj, moe_loss = self._forward_step(
engine,
input_obj,
return_tensors,
@@ -533,13 +542,15 @@ class PipelineScheduler(BaseScheduler):
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
moe_losses.append(moe_loss)
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
moe_loss = moe_losses.pop(0)
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
if i == (num_1f1b_micropairs - 1):
input_obj = None
@@ -563,6 +574,7 @@ class PipelineScheduler(BaseScheduler):
for i in range(num_warmup_microsteps):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
moe_loss = moe_losses.pop(0)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
output_obj_grad = comm.recv_backward(
@@ -574,7 +586,7 @@ class PipelineScheduler(BaseScheduler):
output_obj_grad = None
input_obj_grad = self._backward_step(
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss
)
if not gpc.is_first_rank(ParallelMode.PIPELINE):
@@ -584,7 +596,7 @@ class PipelineScheduler(BaseScheduler):
return output, label, accum_loss
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff: float = 1.0):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -607,9 +619,9 @@ class PipelineScheduler(BaseScheduler):
self.load_batch(engine, data_iter)
if forward_only:
return self._forward_only_step(engine, return_loss, return_output_label)
return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
else:
return self._forward_backward_step(engine, return_loss, return_output_label)
return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)
class InterleavedPipelineScheduler(PipelineScheduler):
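
Finally, a hedged sketch of how a call site might thread the configured coefficient into the new keyword argument. The scheduler and engine objects and the gpc.config.loss lookup are assumptions inferred from the config file above, not code contained in this diff.

# Hypothetical call site for the non-interleaved scheduler.
moe_loss_coeff = gpc.config.loss.get("moe_loss_coeff", 1.0)  # assumed config lookup
output, label, loss = scheduler.forward_backward_step(
    engine,
    data_iter,
    forward_only=False,
    return_loss=True,
    return_output_label=True,
    moe_loss_coeff=moe_loss_coeff,
)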