diff --git a/configs/moe_cfg.py b/configs/moe_cfg.py
new file mode 100644
index 0000000..6fc41f6
--- /dev/null
+++ b/configs/moe_cfg.py
@@ -0,0 +1,152 @@
+JOB_NAME = "7b_train"
+
+SEQ_LEN = 2048
+HIDDEN_SIZE = 4096
+NUM_ATTENTION_HEAD = 32
+MLP_RATIO = 8 / 3
+NUM_LAYER = 16
+VOCAB_SIZE = 103168
+
+MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
+# Ckpt folder format:
+# fs: 'local:/mnt/nfs/XXX'
+SAVE_CKPT_FOLDER = "local:llm_ckpts"
+LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
+
+# boto3 Ckpt folder format:
+# import os
+# BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
+# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
+# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
+CHECKPOINT_EVERY = 50
+ckpt = dict(
+    enable_save_ckpt=False,  # enable ckpt save.
+    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
+    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
+    # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
+    load_optimizer=True,  # Whether to load optimizer states when continuing training.
+    checkpoint_every=CHECKPOINT_EVERY,
+    async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
+    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
+    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage.
+    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
+)
+
+TRAIN_FOLDER = "/path/to/dataset"
+VALID_FOLDER = "/path/to/dataset"
+data = dict(
+    seq_len=SEQ_LEN,
+    # micro_num means the number of micro_batch contained in one gradient update
+    micro_num=4,
+    # packed_length = micro_bsz * SEQ_LEN
+    micro_bsz=2,
+    # defaults to the value of micro_num
+    valid_micro_num=4,
+    # defaults to 0, means disable evaluate
+    valid_every=50,
+    pack_sample_into_one=False,
+    total_steps=50000,
+    skip_batches="",
+    rampup_batch_size="",
+    # Datasets with fewer than 50 rows will be discarded
+    min_length=50,
+    # train_folder=TRAIN_FOLDER,
+    # valid_folder=VALID_FOLDER,
+)
+
+grad_scaler = dict(
+    fp16=dict(
+        # the initial loss scale, defaults to 2**16
+        initial_scale=2**16,
+        # the minimum loss scale, defaults to None
+        min_scale=1,
+        # the number of steps to increase loss scale when no overflow occurs
+        growth_interval=1000,
+    ),
+    # the multiplication factor for increasing loss scale, defaults to 2
+    growth_factor=2,
+    # the multiplication factor for decreasing loss scale, defaults to 0.5
+    backoff_factor=0.5,
+    # the maximum loss scale, defaults to None
+    max_scale=2**24,
+    # the number of overflows before decreasing loss scale, defaults to 2
+    hysteresis=2,
+)
+
+hybrid_zero_optimizer = dict(
+    # Enable low_level_optimizer overlap_communication
+    zero_overlap_communication=False,
+    # bucket size for nccl communication params
+    reduce_bucket_size=512 * 1024 * 1024,
+    # grad clipping
+    clip_grad_norm=1.0,
+)
+
+loss = dict(
+    label_smoothing=0,
+    moe_loss_coeff=0.01,
+)
+
+adam = dict(
+    lr=1e-4,
+    adam_beta1=0.9,
+    adam_beta2=0.95,
+    adam_beta2_c=0,
+    adam_eps=1e-8,
+    weight_decay=0.01,
+)
+
+lr_scheduler = dict(
+    total_steps=data["total_steps"],
+    init_steps=0,  # optimizer_warmup_step
+    warmup_ratio=0.01,
+    eta_min=1e-5,
+    last_epoch=-1,
+)
+
+beta2_scheduler = dict(
+    init_beta2=adam["adam_beta2"],
+    c=adam["adam_beta2_c"],
+    cur_iter=-1,
+)
+
+model = dict(
+    checkpoint=False,
+    num_attention_heads=NUM_ATTENTION_HEAD,
+    embed_split_hidden=True,
+    vocab_size=VOCAB_SIZE,
+    embed_grad_scale=1,
+    parallel_output=True,
+    hidden_size=HIDDEN_SIZE,
+    num_layers=NUM_LAYER,
+    mlp_ratio=MLP_RATIO,
+    apply_post_layer_norm=False,
+    dtype="torch.bfloat16",
+    norm_type="rmsnorm",
+    layer_norm_epsilon=1e-5,
+    use_flash_attn=True,
+    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
+    sequence_parallel=False,
+    num_experts=4,
+    moe_use_residual=True,
+)
+"""
+zero1 parallel:
+    1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
+        so parameters will be divided within the range of dp.
+    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
+    3. if zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
+        For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
+pipeline parallel (dict):
+    1. size: int, the size of pipeline parallel.
+    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
+tensor parallel: tensor parallel size, usually the number of GPUs per node.
+"""
+parallel = dict(
+    # zero1=8,
+    pipeline=dict(size=4, interleaved_overlap=False),
+    tensor=dict(size=2),
+)
+
+cudnn_deterministic = False
+cudnn_benchmark = False
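Note (not part of the diff): the batch-geometry comments in `data` and the new `moe_loss_coeff` entry in `loss` work together as sketched below. The standalone variables and the `total_loss` helper are illustrative assumptions that mirror the config values above, not InternLM code.

# Illustrative sketch only: tokens processed per optimizer step under this config,
# and how moe_loss_coeff is meant to weight the auxiliary MoE (load-balancing) loss
# against the language-modeling loss.
SEQ_LEN = 2048
micro_bsz = 2
micro_num = 4

packed_length = micro_bsz * SEQ_LEN            # 4096 tokens per micro batch
tokens_per_step = micro_num * packed_length    # 16384 tokens per optimizer step, per data-parallel rank

moe_loss_coeff = 0.01

def total_loss(lm_loss, moe_losses):
    # lm_loss: scalar LM loss; moe_losses: list of per-layer gate/balancing losses.
    return lm_loss + moe_loss_coeff * sum(moe_losses)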
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index ac13073..5d7dc04 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -239,7 +239,7 @@ class PipelineScheduler(BaseScheduler):
         """
         return step_id
 
-    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
+    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None, moe_loss_coeff: float = 1.0):
         """
         Forward step for passed-in model. If it is the first stage, the input tensor is obtained from
         data_iterator, otherwise the passed-in input_obj is used.
@@ -259,7 +259,7 @@
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
 
         self._call_hooks("before_forward", data)
-        output_obj = self._call_engine(engine.model, data)
+        output_obj, moe_losses = self._call_engine(engine.model, data)
         self._call_hooks("after_forward", output_obj)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -272,12 +272,14 @@
                 self._call_hooks("after_criterion", loss)
                 loss_reduced = loss / self.num_microbatches
-                accum_loss.add_(loss_reduced.detach())
+                accum_loss.add_(loss_reduced)
                 output_obj = loss_reduced
 
-        return output_obj
+        moe_loss = sum(moe_losses) * moe_loss_coeff
+        moe_loss /= self.num_microbatches
+        return output_obj, moe_loss
 
-    def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
+    def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
         """
         Backward step through the passed-in output tensor. If it is the last stage, the output_obj_grad is
         None, otherwise it is the gradients with respect to stage's output tensor.
 
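The unpacking `output_obj, moe_losses = self._call_engine(engine.model, data)` assumes the model forward now returns a pair: the stage output plus a list of per-layer auxiliary MoE losses. The model-side change is not part of this diff; the toy sketch below only illustrates that contract, and every name in it is made up for the example.

import torch
import torch.nn as nn

class ToyMoEBlock(nn.Module):
    """Tiny stand-in for an MoE FFN: routes by a softmax gate and reports a balancing penalty."""

    def __init__(self, hidden: int, experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(hidden, experts)
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(experts))

    def forward(self, x):
        probs = self.gate(x).softmax(dim=-1)            # (tokens, experts)
        out = sum(p.unsqueeze(-1) * e(x) for p, e in zip(probs.unbind(-1), self.experts))
        aux = probs.mean(dim=0).pow(2).sum() * probs.size(-1)   # crude load-balancing proxy
        return out, aux

class ToyMoEModel(nn.Module):
    def __init__(self, hidden: int = 8, layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList(ToyMoEBlock(hidden) for _ in range(layers))

    def forward(self, x):
        moe_losses = []
        for block in self.blocks:
            x, aux = block(x)
            moe_losses.append(aux)
        # The scheduler's _forward_step expects exactly this two-element return.
        return x, moe_losses

With such a contract in place, `_forward_step` reduces the collected list to one scalar via `sum(moe_losses) * moe_loss_coeff / self.num_microbatches`, matching the hunk above.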
@@ -311,6 +313,9 @@ class PipelineScheduler(BaseScheduler):
 
         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
+            if moe_loss is not None:
+                moe_loss.backward(retain_graph=True)
+
             if output_obj_grad is None:
                 engine.backward(output_obj)
             else:
@@ -329,7 +334,7 @@
 
         return input_obj_grad
 
-    def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
+    def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff: float = 1.0):
         """
         This function performs forward only computation process. The scheduling of microbatches is similar to the
         warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
@@ -376,12 +381,13 @@
                 input_obj = None
 
             # Perform forward computation
-            output_obj = self._forward_step(
+            output_obj, _ = self._forward_step(
                 engine,
                 input_obj,
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                moe_loss_coeff=moe_loss_coeff,
             )
 
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -395,7 +401,7 @@
 
         return output, label, accum_loss
 
-    def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
+    def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff: float = 1.0):
         """
         This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
         It consists of three stages: warmup, 1F1B, and cooldown.
@@ -441,6 +447,7 @@
         # Input, output tensors only need to be saved when doing backward passes
         input_objs = []
         output_objs = []
+        moe_losses = []
         return_tensors = []
         accum_loss = (
             torch.zeros(1, device=get_current_device())
@@ -468,12 +475,13 @@
                 input_obj = None
 
             # Perform forward computation
-            output_obj = self._forward_step(
+            output_obj, moe_loss = self._forward_step(
                 engine,
                 input_obj,
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                moe_loss_coeff=moe_loss_coeff,
             )
 
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -493,6 +501,7 @@
 
             input_objs.append(input_obj)
             output_objs.append(output_obj)
+            moe_losses.append(moe_loss)
 
         # Before running 1F1B, need to receive first forward tensor.
         # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -512,7 +521,7 @@
         # Run 1F1B in steady state.
         for i in range(num_1f1b_micropairs):
             # Perform forward computation
-            output_obj = self._forward_step(
+            output_obj, moe_loss = self._forward_step(
                 engine,
                 input_obj,
                 return_tensors,
@@ -533,13 +542,15 @@
             # Add input_obj and output_obj to end of list.
             input_objs.append(input_obj)
             output_objs.append(output_obj)
+            moe_losses.append(moe_loss)
 
             # Pop output_obj and output_obj from the start of the list for
             # the backward pass.
             input_obj = input_objs.pop(0)
             output_obj = output_objs.pop(0)
+            moe_loss = moe_losses.pop(0)
 
-            input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
+            input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
 
             if i == (num_1f1b_micropairs - 1):
                 input_obj = None
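In `_backward_step`, backpropagating `moe_loss` before the main backward simply accumulates its gradients into the same parameters, and `retain_graph=True` keeps the shared part of the autograd graph alive for the second pass. A self-contained illustration of that behaviour (plain PyTorch, not InternLM code):

import torch

w = torch.ones(2, requires_grad=True)
h = (w * 3.0).sum()        # shared intermediate, analogous to the stage's activations
aux_loss = 0.5 * h         # stands in for the weighted MoE loss
main_loss = 2.0 * h        # stands in for the stage's output_obj / loss

aux_loss.backward(retain_graph=True)   # first backward; keep the graph for the next one
main_loss.backward()                   # gradients accumulate into w.grad

print(w.grad)              # tensor([7.5, 7.5]) == 3*0.5 + 3*2.0 per element

Without `retain_graph=True`, the first backward would free the buffers that the second backward still needs.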
@@ -563,6 +574,7 @@
         for i in range(num_warmup_microsteps):
             input_obj = input_objs.pop(0)
             output_obj = output_objs.pop(0)
+            moe_loss = moe_losses.pop(0)
 
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = comm.recv_backward(
@@ -574,7 +586,7 @@
                 output_obj_grad = None
 
             input_obj_grad = self._backward_step(
-                engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
+                engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss
             )
 
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
@@ -584,7 +596,7 @@
 
         return output, label, accum_loss
 
-    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
+    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff: float = 1.0):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
 
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -607,9 +619,9 @@
         self.load_batch(engine, data_iter)
 
         if forward_only:
-            return self._forward_only_step(engine, return_loss, return_output_label)
+            return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
         else:
-            return self._forward_backward_step(engine, return_loss, return_output_label)
+            return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)
 
 
 class InterleavedPipelineScheduler(PipelineScheduler):
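A hypothetical call site (not part of this diff) tying the two files together: the trainer reads the coefficient from the config's `loss` dict and passes it through the scheduler's new keyword argument. `engine`, `scheduler`, and `train_iter` are assumed to come from the existing InternLM training setup, and the attribute path `gpc.config.loss.moe_loss_coeff` is an assumption about how the config above is exposed at runtime.

# Hypothetical usage sketch; only forward_backward_step's parameters come from the diff.
moe_loss_coeff = gpc.config.loss.moe_loss_coeff  # 0.01 from configs/moe_cfg.py above

_, _, loss = scheduler.forward_backward_step(
    engine,
    train_iter,
    forward_only=False,
    return_loss=True,
    return_output_label=False,
    moe_loss_coeff=moe_loss_coeff,
)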