mirror of https://github.com/InternLM/InternLM
add no-interleaved & no-overlapped moe pp support
parent
d8e5397159
commit
1accc9f08d
|
@ -0,0 +1,152 @@
|
|||
JOB_NAME = "7b_train"
|
||||
|
||||
SEQ_LEN = 2048
|
||||
HIDDEN_SIZE = 4096
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
NUM_LAYER = 16
|
||||
VOCAB_SIZE = 103168
|
||||
|
||||
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
|
||||
# Ckpt folder format:
|
||||
# fs: 'local:/mnt/nfs/XXX'
|
||||
SAVE_CKPT_FOLDER = "local:llm_ckpts"
|
||||
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
|
||||
|
||||
# boto3 Ckpt folder format:
|
||||
# import os
|
||||
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
|
||||
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
|
||||
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
|
||||
CHECKPOINT_EVERY = 50
|
||||
ckpt = dict(
|
||||
enable_save_ckpt=False, # enable ckpt save.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
|
||||
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
|
||||
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
|
||||
load_optimizer=True, # Wheter to load optimizer states when continuing training.
|
||||
checkpoint_every=CHECKPOINT_EVERY,
|
||||
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
|
||||
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
|
||||
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
|
||||
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||
)
|
||||
|
||||
TRAIN_FOLDER = "/path/to/dataset"
|
||||
VALID_FOLDER = "/path/to/dataset"
|
||||
data = dict(
|
||||
seq_len=SEQ_LEN,
|
||||
# micro_num means the number of micro_batch contained in one gradient update
|
||||
micro_num=4,
|
||||
# packed_length = micro_bsz * SEQ_LEN
|
||||
micro_bsz=2,
|
||||
# defaults to the value of micro_num
|
||||
valid_micro_num=4,
|
||||
# defaults to 0, means disable evaluate
|
||||
valid_every=50,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
skip_batches="",
|
||||
rampup_batch_size="",
|
||||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
# valid_folder=VALID_FOLDER,
|
||||
)
|
||||
|
||||
grad_scaler = dict(
|
||||
fp16=dict(
|
||||
# the initial loss scale, defaults to 2**16
|
||||
initial_scale=2**16,
|
||||
# the minimum loss scale, defaults to None
|
||||
min_scale=1,
|
||||
# the number of steps to increase loss scale when no overflow occurs
|
||||
growth_interval=1000,
|
||||
),
|
||||
# the multiplication factor for increasing loss scale, defaults to 2
|
||||
growth_factor=2,
|
||||
# the multiplication factor for decreasing loss scale, defaults to 0.5
|
||||
backoff_factor=0.5,
|
||||
# the maximum loss scale, defaults to None
|
||||
max_scale=2**24,
|
||||
# the number of overflows before decreasing loss scale, defaults to 2
|
||||
hysteresis=2,
|
||||
)
|
||||
|
||||
hybrid_zero_optimizer = dict(
|
||||
# Enable low_level_optimzer overlap_communication
|
||||
zero_overlap_communication=False,
|
||||
# bucket size for nccl communication params
|
||||
reduce_bucket_size=512 * 1024 * 1024,
|
||||
# grad clipping
|
||||
clip_grad_norm=1.0,
|
||||
)
|
||||
|
||||
loss = dict(
|
||||
label_smoothing=0,
|
||||
moe_loss_coeff=0.01,
|
||||
)
|
||||
|
||||
adam = dict(
|
||||
lr=1e-4,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.95,
|
||||
adam_beta2_c=0,
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.01,
|
||||
)
|
||||
|
||||
lr_scheduler = dict(
|
||||
total_steps=data["total_steps"],
|
||||
init_steps=0, # optimizer_warmup_step
|
||||
warmup_ratio=0.01,
|
||||
eta_min=1e-5,
|
||||
last_epoch=-1,
|
||||
)
|
||||
|
||||
beta2_scheduler = dict(
|
||||
init_beta2=adam["adam_beta2"],
|
||||
c=adam["adam_beta2_c"],
|
||||
cur_iter=-1,
|
||||
)
|
||||
|
||||
model = dict(
|
||||
checkpoint=False,
|
||||
num_attention_heads=NUM_ATTENTION_HEAD,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_layers=NUM_LAYER,
|
||||
mlp_ratio=MLP_RATIO,
|
||||
apply_post_layer_norm=False,
|
||||
dtype="torch.bfloat16",
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
|
||||
sequence_parallel=False,
|
||||
num_experts=4,
|
||||
moe_use_residual=True,
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
|
||||
so parameters will be divided within the range of dp.
|
||||
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
|
||||
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
|
||||
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
|
||||
pipeline parallel (dict):
|
||||
1. size: int, the size of pipeline parallel.
|
||||
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
|
||||
tensor parallel: tensor parallel size, usually the number of GPUs per node.
|
||||
"""
|
||||
parallel = dict(
|
||||
# zero1=8,
|
||||
pipeline=dict(size=4, interleaved_overlap=False),
|
||||
tensor=dict(size=2),
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
cudnn_benchmark = False
|
|
@ -239,7 +239,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
"""
|
||||
return step_id
|
||||
|
||||
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
|
||||
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff:float=1.0):
|
||||
"""
|
||||
Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
|
@ -259,7 +259,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
|
||||
|
||||
self._call_hooks("before_forward", data)
|
||||
output_obj = self._call_engine(engine.model, data)
|
||||
output_obj, moe_losses = self._call_engine(engine.model, data)
|
||||
self._call_hooks("after_forward", output_obj)
|
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@ -272,12 +272,14 @@ class PipelineScheduler(BaseScheduler):
|
|||
self._call_hooks("after_criterion", loss)
|
||||
|
||||
loss_reduced = loss / self.num_microbatches
|
||||
accum_loss.add_(loss_reduced.detach())
|
||||
accum_loss.add_(loss_reduced)
|
||||
output_obj = loss_reduced
|
||||
|
||||
return output_obj
|
||||
moe_loss = sum(moe_losses) * moe_loss_coeff
|
||||
moe_loss /= self.num_microbatches
|
||||
return output_obj, moe_loss
|
||||
|
||||
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
|
||||
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
|
||||
"""
|
||||
Backward step through the passed-in output tensor. If it is the last stage, the
|
||||
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
||||
|
@ -311,6 +313,9 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
self._call_hooks("before_backward", output_obj, output_obj_grad)
|
||||
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
|
||||
if moe_loss is not None:
|
||||
moe_loss.backward(retain_graph=True)
|
||||
|
||||
if output_obj_grad is None:
|
||||
engine.backward(output_obj)
|
||||
else:
|
||||
|
@ -329,7 +334,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return input_obj_grad
|
||||
|
||||
def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
|
||||
def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
|
||||
"""
|
||||
This function performs forward only computation process. The scheduling of microbatches is similar to the
|
||||
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
|
||||
|
@ -376,12 +381,13 @@ class PipelineScheduler(BaseScheduler):
|
|||
input_obj = None
|
||||
|
||||
# Perform forward computation
|
||||
output_obj = self._forward_step(
|
||||
output_obj, _ = self._forward_step(
|
||||
engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@ -395,7 +401,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return output, label, accum_loss
|
||||
|
||||
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
|
||||
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
|
||||
"""
|
||||
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
|
||||
It consists of three stages: warmup, 1F1B, and cooldown.
|
||||
|
@ -441,6 +447,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
# Input, output tensors only need to be saved when doing backward passes
|
||||
input_objs = []
|
||||
output_objs = []
|
||||
moe_losses = []
|
||||
return_tensors = []
|
||||
accum_loss = (
|
||||
torch.zeros(1, device=get_current_device())
|
||||
|
@ -468,12 +475,13 @@ class PipelineScheduler(BaseScheduler):
|
|||
input_obj = None
|
||||
|
||||
# Perform forward computation
|
||||
output_obj = self._forward_step(
|
||||
output_obj, moe_loss = self._forward_step(
|
||||
engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
return_output_label=return_output_label,
|
||||
accum_loss=accum_loss,
|
||||
moe_loss_coeff=moe_loss_coeff,
|
||||
)
|
||||
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
|
@ -493,6 +501,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
moe_losses.append(moe_loss)
|
||||
|
||||
# Before running 1F1B, need to receive first forward tensor.
|
||||
# If all microbatches are run in warmup / cooldown phase, then no need to
|
||||
|
@ -512,7 +521,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
# Run 1F1B in steady state.
|
||||
for i in range(num_1f1b_micropairs):
|
||||
# Perform forward computation
|
||||
output_obj = self._forward_step(
|
||||
output_obj, moe_loss = self._forward_step(
|
||||
engine,
|
||||
input_obj,
|
||||
return_tensors,
|
||||
|
@ -533,13 +542,15 @@ class PipelineScheduler(BaseScheduler):
|
|||
# Add input_obj and output_obj to end of list.
|
||||
input_objs.append(input_obj)
|
||||
output_objs.append(output_obj)
|
||||
moe_losses.append(moe_loss)
|
||||
|
||||
# Pop output_obj and output_obj from the start of the list for
|
||||
# the backward pass.
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
moe_loss = moe_losses.pop(0)
|
||||
|
||||
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
|
||||
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
|
||||
|
||||
if i == (num_1f1b_micropairs - 1):
|
||||
input_obj = None
|
||||
|
@ -563,6 +574,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
for i in range(num_warmup_microsteps):
|
||||
input_obj = input_objs.pop(0)
|
||||
output_obj = output_objs.pop(0)
|
||||
moe_loss = moe_losses.pop(0)
|
||||
|
||||
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
output_obj_grad = comm.recv_backward(
|
||||
|
@ -574,7 +586,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
output_obj_grad = None
|
||||
|
||||
input_obj_grad = self._backward_step(
|
||||
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
|
||||
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss
|
||||
)
|
||||
|
||||
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
|
@ -584,7 +596,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return output, label, accum_loss
|
||||
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff:float=1.0):
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||
|
||||
|
@ -607,9 +619,9 @@ class PipelineScheduler(BaseScheduler):
|
|||
self.load_batch(engine, data_iter)
|
||||
|
||||
if forward_only:
|
||||
return self._forward_only_step(engine, return_loss, return_output_label)
|
||||
return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
|
||||
else:
|
||||
return self._forward_backward_step(engine, return_loss, return_output_label)
|
||||
return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)
|
||||
|
||||
|
||||
class InterleavedPipelineScheduler(PipelineScheduler):
|
||||
|
|
Loading…
Reference in New Issue