Merge pull request #2 from blankde/feature_add_moe_pp_zl

feat(moe): moe pipeline support
Wenwen Qu 2023-08-23 13:51:37 +08:00 committed by GitHub
commit 401796940a
6 changed files with 278 additions and 40 deletions

configs/moe_cfg.py (new file, 152 lines added)

JOB_NAME = "7b_train"
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 16
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
    enable_save_ckpt=False,  # enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
    # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
    load_optimizer=True,  # Whether to load optimizer states when continuing training.
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage path.
    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
)
TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki"
VALID_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki"
data = dict(
    seq_len=SEQ_LEN,
    # micro_num means the number of micro_batch contained in one gradient update
    micro_num=4,
    packed_length=2 * SEQ_LEN,
    micro_bsz=2,
    # defaults to the value of micro_num
    valid_micro_num=4,
    # defaults to 0, means disable evaluate
    valid_every=50000,
    pack_sample_into_one=False,
    total_steps=50000,
    skip_batches="",
    rampup_batch_size="",
    # Datasets with less than 50 rows will be discarded
    min_length=50,
    train_folder=TRAIN_FOLDER,
    valid_folder=VALID_FOLDER,
)
grad_scaler = dict(
    fp16=dict(
        # the initial loss scale, defaults to 2**16
        initial_scale=2**16,
        # the minimum loss scale, defaults to None
        min_scale=1,
        # the number of steps to increase loss scale when no overflow occurs
        growth_interval=1000,
    ),
    # the multiplication factor for increasing loss scale, defaults to 2
    growth_factor=2,
    # the multiplication factor for decreasing loss scale, defaults to 0.5
    backoff_factor=0.5,
    # the maximum loss scale, defaults to None
    max_scale=2**24,
    # the number of overflows before decreasing loss scale, defaults to 2
    hysteresis=2,
)
hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    zero_overlap_communication=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,
)
loss = dict(
    label_smoothing=0,
    moe_loss_coeff=0.1,
)
adam = dict(
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_beta2_c=0,
    adam_eps=1e-8,
    weight_decay=0.01,
)
lr_scheduler = dict(
    total_steps=data["total_steps"],
    init_steps=0,  # optimizer_warmup_step
    warmup_ratio=0.01,
    eta_min=1e-5,
    last_epoch=-1,
)
beta2_scheduler = dict(
    init_beta2=adam["adam_beta2"],
    c=adam["adam_beta2_c"],
    cur_iter=-1,
)
model = dict(
    checkpoint=False,
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYER,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.bfloat16",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    sequence_parallel=False,
    num_experts=4,
    moe_use_residual=False,
)
"""
zero1 parallel:
    1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
       so parameters will be divided within the range of dp.
    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
    3. if zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of the dp world size.
       For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
    1. size: int, the size of pipeline parallel.
    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
    # zero1=4,
    pipeline=dict(size=4, interleaved_overlap=False),
    # tensor=dict(size=4),
)
cudnn_deterministic = False
cudnn_benchmark = False
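The parallel-settings comment in the config above fixes how the ranks are split between tensor, pipeline, and data parallelism. A small worked example; the 32-GPU world size is an assumption for illustration, not something this config specifies:

```python
# Worked example for the parallel settings above; world_size=32 is assumed for illustration.
world_size, tensor, pipeline = 32, 1, 4
dp = world_size // (tensor * pipeline)  # data-parallel size left after the tensor/pipeline split
assert dp == 8  # so zero1 may be any value in 1..8; zero1 <= 0 would span the whole dp group
```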

@@ -127,7 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_loss:
             loss = None
-        return output, loss
+        return output, loss, moe_loss
 
     def forward_backward_step(
         self,
@@ -166,6 +166,7 @@ class NonPipelineScheduler(BaseScheduler):
         data, label = batch_data
 
         loss = 0 if return_loss else None
+        moe_loss = 0 if return_loss else None
 
         outputs = []
         labels = []
@@ -180,12 +181,14 @@ class NonPipelineScheduler(BaseScheduler):
             _data, _label = self._load_accum_batch(data, label)
 
-            _output, _loss = self._train_one_batch(
+            _output, _loss, _moe_loss = self._train_one_batch(
                 _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
             )
 
             if return_loss:
                 loss += _loss
+                moe_loss += _moe_loss
+
             if return_output_label:
                 outputs.append(_output)
                 labels.append(_label)
@@ -193,4 +196,4 @@ class NonPipelineScheduler(BaseScheduler):
         if not return_output_label:
             outputs, labels = None, None
 
-        return outputs, labels, loss
+        return outputs, labels, loss, moe_loss
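The accumulation loop above sums a `_moe_loss` term returned by `_train_one_batch` alongside the language-model loss. As a hedged sketch of how such a helper might fold `moe_loss_coeff` into the objective (the names and structure here are illustrative assumptions, not the code changed by this commit):

```python
import torch

def train_one_batch_sketch(model, criterion, data, label, moe_loss_coeff=0.1, grad_accum_size=1):
    """Illustrative only: assumes the model returns (logits, per-layer gate losses)."""
    logits, gate_losses = model(data)
    loss = criterion(logits, label) / grad_accum_size
    moe_loss = sum(gate_losses) * moe_loss_coeff / grad_accum_size
    (loss + moe_loss).backward()  # one backward over the combined objective
    return logits, loss.detach(), moe_loss.detach()  # MoE term kept separate for accumulation and logging
```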

@@ -7,6 +7,7 @@ from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch.cuda
+import torch.distributed as dist
 
 import internlm.core.communication as comm
 from internlm.core.context import ParallelMode
@@ -239,7 +240,16 @@ class PipelineScheduler(BaseScheduler):
         """
         return step_id
 
-    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
+    def _forward_step(
+        self,
+        engine,
+        input_obj,
+        return_tensors,
+        return_output_label=True,
+        accum_loss=None,
+        accum_moe_loss=None,
+        moe_loss_coeff=1.0,
+    ):
         """
         Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
@@ -251,6 +261,7 @@ class PipelineScheduler(BaseScheduler):
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
             accum_loss (optional): Where accumulated loss stores.
+            accum_moe_loss (optional): Where accumulated moe loss stores.
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
             pipeline stage.
@@ -259,7 +270,7 @@ class PipelineScheduler(BaseScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
 
         self._call_hooks("before_forward", data)
-        output_obj = self._call_engine(engine.model, data)
+        output_obj, moe_losses = self._call_engine(engine.model, data)
         self._call_hooks("after_forward", output_obj)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
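The `_call_engine` change in the hunk above assumes the wrapped model now returns a pair: the usual stage output (or reduced loss on the last stage) plus a list of per-layer MoE gate losses. A toy model that satisfies that contract, purely for illustration and not the InternLM modeling code:

```python
import torch
import torch.nn as nn

class ToyMoEBlock(nn.Module):
    def __init__(self, dim: int, num_experts: int = 4):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_experts))

    def forward(self, x):
        probs = self.gate(x).softmax(dim=-1)
        out = sum(p.unsqueeze(-1) * expert(x) for p, expert in zip(probs.unbind(-1), self.experts))
        aux_loss = probs.mean(dim=(0, 1)).var()  # toy load-balancing penalty on expert usage
        return out, aux_loss

class ToyMoEModel(nn.Module):
    def __init__(self, dim: int = 16, depth: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList(ToyMoEBlock(dim) for _ in range(depth))

    def forward(self, x):
        moe_losses = []
        for block in self.blocks:
            x, aux = block(x)
            moe_losses.append(aux)
        return x, moe_losses  # the (output, moe_losses) pair the scheduler now unpacks

# out, moe_losses = ToyMoEModel()(torch.randn(2, 8, 16))
```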
@@ -275,9 +286,13 @@ class PipelineScheduler(BaseScheduler):
                 accum_loss.add_(loss_reduced.detach())
                 output_obj = loss_reduced
 
-        return output_obj
+        moe_loss = sum(moe_losses) * moe_loss_coeff
+        moe_loss /= self.num_microbatches
+        accum_moe_loss.add_(moe_loss.detach())
 
-    def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
+        return output_obj, moe_loss
+
+    def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
         """
         Backward step through the passed-in output tensor. If it is the last stage, the
         output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
@@ -311,10 +326,18 @@ class PipelineScheduler(BaseScheduler):
         self._call_hooks("before_backward", output_obj, output_obj_grad)
         with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
-            if output_obj_grad is None:
-                engine.backward(output_obj)
+            if moe_loss is None:
+                if output_obj_grad is None:
+                    engine.backward(output_obj)
+                else:
+                    engine.backward_by_grad(output_obj, output_obj_grad)
             else:
-                engine.backward_by_grad(output_obj, output_obj_grad)
+                if output_obj_grad is None:
+                    engine.backward(output_obj + moe_loss)
+                else:
+                    # scale the latent loss
+                    moe_loss = moe_loss * engine.optimizer.loss_scale
+                    engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
 
         # Collect the grad of the input_obj.
         input_obj_grad = None
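In the new branch above, a non-last stage backpropagates the gradient received from the next stage through `output_obj` while simultaneously backpropagating its locally produced scalar `moe_loss`, whose implicit gradient is 1 (hence the `None`). A standalone autograd sketch of that call pattern; `engine.backward_by_grad` is assumed here to reduce to something like this:

```python
import torch

x = torch.randn(4, 8, requires_grad=True)      # stand-in for this stage's activations
output_obj = (x * 2).relu()                     # tensor that was sent on to the next stage
moe_loss = x.pow(2).mean() * 0.01               # scalar gate loss produced locally

output_obj_grad = torch.ones_like(output_obj)   # gradient received back from the next stage
# One backward call over both heads; None lets autograd use the implicit grad of 1 for the scalar.
torch.autograd.backward([output_obj, moe_loss], [output_obj_grad, None])
print(x.grad.norm())
```

The extra multiplication by `engine.optimizer.loss_scale` in the diff presumably keeps the auxiliary loss on the same scale as the incoming `output_obj_grad`, which has already been amplified by the fp16 loss scaler.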
@@ -329,7 +352,7 @@ class PipelineScheduler(BaseScheduler):
 
         return input_obj_grad
 
-    def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
+    def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
         """
         This function performs forward only computation process. The scheduling of microbatches is similar to the
         warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
@@ -356,6 +379,7 @@ class PipelineScheduler(BaseScheduler):
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
+        accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -376,12 +400,14 @@ class PipelineScheduler(BaseScheduler):
                 input_obj = None
 
             # Perform forward computation
-            output_obj = self._forward_step(
+            output_obj, _ = self._forward_step(
                 engine,
                 input_obj,
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
+                moe_loss_coeff=moe_loss_coeff,
             )
 
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -392,10 +418,14 @@ class PipelineScheduler(BaseScheduler):
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
-        return output, label, accum_loss
+        if accum_loss is not None:
+            accum_loss += accum_moe_loss
+
+        return output, label, accum_loss, accum_moe_loss
 
-    def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
+    def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
         """
         This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
         It consists of three stages: warmup, 1F1B, and cooldown.
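Because each pipeline stage only sees the gate losses of the layers it hosts, the per-stage `accum_moe_loss` is summed over the pipeline group before being returned (and folded into `accum_loss` where that tensor exists). A minimal sketch of that reduction, assuming an already-initialized process group:

```python
import torch
import torch.distributed as dist

def reduce_moe_loss(local_moe_loss: torch.Tensor, pipeline_group=None) -> torch.Tensor:
    """Sum the per-stage MoE loss across pipeline ranks, in place (as the diff does)."""
    dist.all_reduce(local_moe_loss, op=dist.ReduceOp.SUM, group=pipeline_group)
    return local_moe_loss
```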
@@ -441,12 +471,14 @@ class PipelineScheduler(BaseScheduler):
         # Input, output tensors only need to be saved when doing backward passes
         input_objs = []
         output_objs = []
+        moe_losses = []
         return_tensors = []
         accum_loss = (
             torch.zeros(1, device=get_current_device())
             if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
             else None
         )
+        accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         # Used for tensor meta information communication
         forward_recv_shapes = self.tensor_shape
@@ -468,12 +500,14 @@ class PipelineScheduler(BaseScheduler):
                 input_obj = None
 
             # Perform forward computation
-            output_obj = self._forward_step(
+            output_obj, moe_loss = self._forward_step(
                 engine,
                 input_obj,
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
+                moe_loss_coeff=moe_loss_coeff,
             )
 
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -493,6 +527,7 @@ class PipelineScheduler(BaseScheduler):
 
             input_objs.append(input_obj)
             output_objs.append(output_obj)
+            moe_losses.append(moe_loss)
 
         # Before running 1F1B, need to receive first forward tensor.
         # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -512,12 +547,14 @@ class PipelineScheduler(BaseScheduler):
         # Run 1F1B in steady state.
         for i in range(num_1f1b_micropairs):
             # Perform forward computation
-            output_obj = self._forward_step(
+            output_obj, moe_loss = self._forward_step(
                 engine,
                 input_obj,
                 return_tensors,
                 return_output_label=return_output_label,
                 accum_loss=accum_loss,
+                accum_moe_loss=accum_moe_loss,
+                moe_loss_coeff=moe_loss_coeff,
             )
 
             if gpc.is_last_rank(ParallelMode.PIPELINE):
@@ -533,13 +570,15 @@ class PipelineScheduler(BaseScheduler):
             # Add input_obj and output_obj to end of list.
             input_objs.append(input_obj)
             output_objs.append(output_obj)
+            moe_losses.append(moe_loss)
 
             # Pop output_obj and output_obj from the start of the list for
             # the backward pass.
             input_obj = input_objs.pop(0)
             output_obj = output_objs.pop(0)
+            moe_loss = moe_losses.pop(0)
 
-            input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
+            input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
 
             if i == (num_1f1b_micropairs - 1):
                 input_obj = None
@@ -563,6 +602,7 @@ class PipelineScheduler(BaseScheduler):
         for i in range(num_warmup_microsteps):
             input_obj = input_objs.pop(0)
             output_obj = output_objs.pop(0)
+            moe_loss = moe_losses.pop(0)
 
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = comm.recv_backward(
@@ -574,17 +614,25 @@ class PipelineScheduler(BaseScheduler):
                 output_obj_grad = None
 
             input_obj_grad = self._backward_step(
-                engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
+                engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss
             )
 
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
+        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
+
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
+        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
-        return output, label, accum_loss
+        if accum_loss is not None:
+            accum_loss += accum_moe_loss
 
-    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
+        return output, label, accum_loss, accum_moe_loss
+
+    def forward_backward_step(
+        self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
+    ):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
 
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -596,7 +644,7 @@ class PipelineScheduler(BaseScheduler):
             return_loss (bool, optional): Whether returns the loss value. Default is true.
             return_output_label (bool, optional): If False, the output and label won't be returned.
         Returns:
-            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
         """
 
         assert (
@@ -607,9 +655,9 @@ class PipelineScheduler(BaseScheduler):
         self.load_batch(engine, data_iter)
 
         if forward_only:
-            return self._forward_only_step(engine, return_loss, return_output_label)
+            return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
         else:
-            return self._forward_backward_step(engine, return_loss, return_output_label)
+            return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)
 
 
 class InterleavedPipelineScheduler(PipelineScheduler):
@@ -676,10 +724,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
 
         self._accum_loss = None
+        self._accum_moe_loss = None
         self._return_tensors = None
         self._input_objs = [[] for _ in range(num_chunks)]
         self._output_objs = [[] for _ in range(num_chunks)]
         self._output_obj_grads = [[] for _ in range(num_chunks)]
+        self._moe_losses = [[] for _ in range(num_chunks)]
 
         self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
         self._output_obj_shapes = [None for _ in range(num_chunks)]
@@ -687,10 +737,12 @@ class InterleavedPipelineScheduler(PipelineScheduler):
     def _clear_state(self) -> None:
         self._accum_loss = None
+        self._accum_moe_loss = None
         self._return_tensors = None
         self._input_objs = [[] for _ in range(self._num_chunks)]
         self._output_objs = [[] for _ in range(self._num_chunks)]
         self._output_obj_grads = [[] for _ in range(self._num_chunks)]
+        self._moe_losses = [[] for _ in range(self._num_chunks)]
 
         self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
         self._output_obj_shapes = [None for _ in range(self._num_chunks)]
@@ -712,7 +764,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             self.microbatch_offset[model_chunk_id] += self.microbatch_size
         return move_to_device(micro_batch_data)
 
-    def _forward_step(self, engine, chunk_id):
+    def _forward_step(self, engine, chunk_id, moe_loss_coeff=1.0):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.
@@ -734,7 +786,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
 
         self._call_hooks("before_forward", data)
-        output_obj = self._call_engine(engine.model[chunk_id], data)
+        output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
 
         # Convert output_obj to fp32 when last model chunk of last stage
         if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
             output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@@ -754,7 +806,14 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 self._accum_loss.add_(loss_reduced.detach())
                 output_obj = loss_reduced
 
+        moe_loss = sum(moe_losses) * moe_loss_coeff
+        moe_loss /= self.num_microbatches
+
+        if self._accum_moe_loss is not None:
+            self._accum_moe_loss.add_(moe_loss.detach())
+
         self._output_objs[chunk_id].append(output_obj)
+        self._moe_losses[chunk_id].append(moe_loss)
 
         return output_obj
@@ -780,8 +839,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         input_obj = self._input_objs[chunk_id].pop(0)
         output_obj = self._output_objs[chunk_id].pop(0)
         output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
+        moe_loss = self._moe_losses[chunk_id].pop(0)
 
-        input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
+        input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss)
 
         return input_obj_grad
@@ -813,6 +873,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         num_warmup_microsteps: int,
         receive_extra_backward: bool = False,
         forward_only: bool = False,
+        moe_loss_coeff: float = 1.0,
     ) -> None:
         """
         Run the warm-up loop and prepare data for the 1F1B stage.
@@ -850,12 +911,13 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         for k in range(num_warmup_microsteps):
             chunk_id = self._get_chunk_by_microbatch(k)
 
-            output_obj = self._forward_step(engine, chunk_id)
+            output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff)
 
             if forward_only:
                 # when forward-only, no need to save tensors for a backward pass
                 self._input_objs[chunk_id].pop()
                 self._output_objs[chunk_id].pop()
+                self._moe_losses[chunk_id].pop()
 
             if not gpc.is_pipeline_last_stage():
                 if isinstance(output_obj, torch.Tensor):
@@ -931,6 +993,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         num_warmup_microsteps: int,
         num_1f1b_micropairs: int,
         all_warmup_microsteps: bool = False,
+        moe_loss_coeff: float = 1.0,
     ) -> None:
         """
         Run the 1F1B loop with overlap.
@@ -960,7 +1023,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
 
             # 1. Forward pass.
-            output_obj = self._forward_step(engine, forward_chunk_id)
+            output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
 
             # 2. Check if the backward input is ready.
             if backward_async_communicator is not None:
@@ -1045,6 +1108,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         num_warmup_microsteps: int,
         num_1f1b_micropairs: int,
         all_warmup_microsteps: bool = False,
+        moe_loss_coeff: float = 1.0,
     ) -> None:
         """
         Run the 1F1B loop without overlap.
@@ -1066,7 +1130,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             # Forward pass.
             forward_microstep_id = k + num_warmup_microsteps
             forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
-            output_obj = self._forward_step(engine, forward_chunk_id)
+            output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
 
             # Backward pass.
             backward_microstep_id = k
@@ -1171,7 +1235,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 )
             )
 
-    def _forward_only_step(self, engine: Engine):
+    def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
         num_microsteps = self.num_microbatches * self._num_chunks
         num_warmup_microsteps = num_microsteps
@@ -1181,9 +1245,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             num_warmup_microsteps,
             receive_extra_backward=False,
             forward_only=True,
+            moe_loss_coeff=moe_loss_coeff,
         )
 
-    def _forward_backward_step(self, engine: Engine):
+    def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
         # Compute number of warmup and remaining microbatches.
         all_warmup_microsteps = False
         num_microsteps = self.num_microbatches * self._num_chunks
@@ -1217,6 +1282,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             num_microsteps,
             num_warmup_steps,
             receive_extra_backward=receive_extra_backward,
+            moe_loss_coeff=moe_loss_coeff,
         )
 
         # 2. 1F1B
@@ -1225,12 +1291,15 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             num_warmup_steps,
             num_1f1b_micropairs=num_1f1b_micropairs,
             all_warmup_microsteps=all_warmup_microsteps,
+            moe_loss_coeff=moe_loss_coeff,
         )
 
         # 3. Cooldown
         self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
 
-    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
+    def forward_backward_step(
+        self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
+    ):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.
@@ -1254,20 +1323,30 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             self._accum_loss = torch.zeros(1, device=get_current_device())
+        self._accum_moe_loss = torch.zeros(1, device=get_current_device())
 
         if return_output_label:
             self._return_tensors = []
 
         if forward_only:
-            self._forward_only_step(engine)
+            self._forward_only_step(engine, moe_loss_coeff)
         else:
-            self._forward_backward_step(engine)
+            self._forward_backward_step(engine, moe_loss_coeff)
 
         if return_output_label and len(self._return_tensors) > 0:
             output, label = pack_return_tensors(self._return_tensors)
         else:
             output, label = (None, None)
 
+        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
+
+        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        accum_moe_loss = self._accum_moe_loss
+
         accum_loss = self._accum_loss
+        if accum_loss is not None:
+            accum_loss += self._accum_moe_loss
+
         self._clear_state()
 
-        return output, label, accum_loss
+        return output, label, accum_loss, accum_moe_loss

@@ -155,5 +155,5 @@ class Trainer:
         Returns:
             Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
         """
-        output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
-        return output, label, loss
+        output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
+        return output, label, loss, moe_loss
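With `execute_schedule` now returning four values, call sites unpack (or discard) the extra MoE term, as the evaluation and training loops below do. A hypothetical call-site wrapper, for illustration only:

```python
from typing import Any, Tuple
import torch

def run_train_step(trainer: Any, batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper mirroring the updated 4-tuple return of execute_schedule."""
    _, _, loss, moe_loss = trainer.execute_schedule(
        batch, forward_only=False, return_loss=True, return_output_label=False
    )
    return loss, moe_loss
```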

@@ -100,7 +100,7 @@ def evaluate_on_val_dls(
                     tensor_shape=tensor_shape,
                     metric_hook_list=[val_sche_metric_hook],
                 ):
-                    _, _, loss = trainer.execute_schedule(
+                    _, _, loss, _ = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
             else:
@@ -114,7 +114,7 @@ def evaluate_on_val_dls(
                     grad_accum_batch_size=grad_accum_batch_size,
                     metric_hook_list=[val_sche_metric_hook],
                 ):
-                    _, _, loss = trainer.execute_schedule(
+                    _, _, loss, _ = trainer.execute_schedule(
                         batch, forward_only=True, return_loss=True, return_output_label=False
                     )
             if verbose:

@@ -346,6 +346,7 @@ def record_current_batch_training_metrics(
     trainer,
     start_time,
     loss,
+    moe_loss,
     grad_norm,
     metric,
     update_panel,
@@ -389,6 +390,7 @@ def record_current_batch_training_metrics(
                 "tflops": tflops,
                 "step": batch_count,
                 "loss": loss.item(),
+                "moe_loss": moe_loss.item(),
                 "tgs (tokens/gpu/second)": tk_per_gpu,
                 "lr": lr,
                 "loss_scale": scaler,
@@ -424,6 +426,7 @@ def record_current_batch_training_metrics(
                 "num_consumed_tokens": train_state.num_consumed_tokens,
                 "grad_norm": grad_norm,
                 "loss": loss.item(),
+                "moe_loss": moe_loss.item(),
                 "flops": tflops,
                 "tgs": tk_per_gpu,
                 "acc": acc_perplex["acc"],
@@ -629,7 +632,7 @@ def main(args):
         # do forward and backward
         timer("fwd-bwd").start()
 
-        _, _, loss = trainer.execute_schedule(
+        _, _, loss, moe_loss = trainer.execute_schedule(
             batch,
             forward_only=False,
             return_loss=True,
@@ -667,6 +670,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
+            moe_loss=moe_loss,
             grad_norm=np.array(grad_norm_groups),
             metric=metric,
             update_panel=uniscale_logger is not None,