merge branch 'feature_add_moe' into feature_add_moe_data

pull/375/head
Wenwen Qu 2023-08-24 17:32:40 +08:00
commit 0e2eb90d22
23 changed files with 1808 additions and 948 deletions

configs/moe_cfg.py Normal file
View File

@ -0,0 +1,152 @@
JOB_NAME = "7b_train"
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 16
VOCAB_SIZE = 103168
MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
load_optimizer=True, # Whether to load optimizer states when continuing training.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki"
VALID_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki"
data = dict(
seq_len=SEQ_LEN,
# micro_num is the number of micro-batches contained in one gradient update
micro_num=4,
packed_length = 2 * SEQ_LEN,
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, which disables evaluation
valid_every=50000,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
rampup_batch_size="",
# Datasets with fewer than 50 rows will be discarded
min_length=50,
train_folder=TRAIN_FOLDER,
valid_folder=VALID_FOLDER,
)
grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of steps to increase loss scale when no overflow occurs
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)
hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
zero_overlap_communication=True,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)
loss = dict(
label_smoothing=0,
moe_loss_coeff=0.1,
)
adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)
lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)
beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)
model = dict(
checkpoint=False,
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
sequence_parallel=False,
num_experts=4,
moe_use_residual=False,
)
"""
zero1 parallel:
1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. if zero1 > 1 and zero1 <= dp world size, the zero process group is a subset of the dp process group.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
# zero1=4,
pipeline=dict(size=4, interleaved_overlap=False),
# tensor=dict(size=4),
)
cudnn_deterministic = False
cudnn_benchmark = False
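For readers unfamiliar with how the parallel settings in this config interact, here is a minimal sketch (not part of this commit; the function name and world_size value are assumptions for illustration) of how the data-parallel and zero1 group sizes follow from the docstring above.

def derive_parallel_sizes(world_size, zero1=-1, pipeline=4, tensor=1):
    # data-parallel size is what remains after pipeline and tensor parallelism
    dp = world_size // (pipeline * tensor)
    # zero1 <= 0: shard optimizer states across the whole dp group;
    # zero1 == 1: no sharding; otherwise zero1 must divide dp
    zero = dp if zero1 <= 0 else zero1
    assert zero == 1 or dp % zero == 0, "zero1 must divide the data-parallel world size"
    return dp, zero

# e.g. 32 GPUs with pipeline=4 and tensor=1 gives 8 data-parallel ranks
print(derive_parallel_sizes(32, zero1=-1, pipeline=4, tensor=1))  # (8, 8)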

View File

@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
moe_loss = sum(moe_losses) * moe_loss_coeff
loss += moe_loss
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss
# backward
if not forward_only:
@ -127,7 +128,7 @@ class NonPipelineScheduler(BaseScheduler):
if not return_loss:
loss = None
return output, loss
return output, loss, moe_loss
def forward_backward_step(
self,
@ -166,6 +167,7 @@ class NonPipelineScheduler(BaseScheduler):
data, label = batch_data
loss = 0 if return_loss else None
moe_loss = 0 if return_loss else None
outputs = []
labels = []
@ -180,12 +182,14 @@ class NonPipelineScheduler(BaseScheduler):
_data, _label = self._load_accum_batch(data, label)
_output, _loss = self._train_one_batch(
_output, _loss, _moe_loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
)
if return_loss:
loss += _loss
moe_loss += _moe_loss
if return_output_label:
outputs.append(_output)
labels.append(_label)
@ -193,4 +197,4 @@ class NonPipelineScheduler(BaseScheduler):
if not return_output_label:
outputs, labels = None, None
return outputs, labels, loss
return outputs, labels, loss, moe_loss
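As a standalone illustration of the loss handling this scheduler change introduces for gradient accumulation, the sketch below (assumed names, not the repository's code) shows how the main loss and the auxiliary MoE loss are both divided by the accumulation size before being combined, mirroring the hunk above.

import torch

def combine_losses(loss, moe_losses, moe_loss_coeff, grad_accum_size):
    # auxiliary load-balancing losses from each MoE layer, weighted by the coefficient
    moe_loss = sum(moe_losses) * moe_loss_coeff
    moe_loss = moe_loss / grad_accum_size  # scaled the same way as the main loss
    loss = loss / grad_accum_size
    return loss + moe_loss, moe_loss

total, moe = combine_losses(torch.tensor(2.0), [torch.tensor(0.3), torch.tensor(0.1)], 0.1, 4)
print(total.item(), moe.item())  # ~0.51, ~0.01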

View File

@ -7,6 +7,7 @@ from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union
import torch.cuda
import torch.distributed as dist
import internlm.core.communication as comm
from internlm.core.context import ParallelMode
@ -130,7 +131,7 @@ class PipelineScheduler(BaseScheduler):
self.dtype = dtype
self._hooks = scheduler_hooks
self.tensor_shape = (
self._tensor_shape = (
tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
)
@ -146,6 +147,14 @@ class PipelineScheduler(BaseScheduler):
# cache for the batch data
self.batch_data = None
@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape
@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
def pre_processing(self, engine):
types = set()
@ -239,7 +248,16 @@ class PipelineScheduler(BaseScheduler):
"""
return step_id
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(
self,
engine,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None,
accum_moe_loss=None,
moe_loss_coeff=1.0,
):
"""
Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
@ -251,6 +269,7 @@ class PipelineScheduler(BaseScheduler):
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether returns output labels.
accum_loss (optional): Where accumulated loss stores.
accum_moe_loss (optional): Where accumulated moe loss stores.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
pipeline stage.
@ -259,7 +278,7 @@ class PipelineScheduler(BaseScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model, data)
output_obj, moe_losses = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)
if gpc.is_last_rank(ParallelMode.PIPELINE):
@ -275,9 +294,13 @@ class PipelineScheduler(BaseScheduler):
accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
return output_obj
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
return output_obj, moe_loss
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
"""
Backward step through the passed-in output tensor. If it is the last stage, the
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
@ -311,10 +334,18 @@ class PipelineScheduler(BaseScheduler):
self._call_hooks("before_backward", output_obj, output_obj_grad)
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
if output_obj_grad is None:
engine.backward(output_obj)
if moe_loss is None:
if output_obj_grad is None:
engine.backward(output_obj)
else:
engine.backward_by_grad(output_obj, output_obj_grad)
else:
engine.backward_by_grad(output_obj, output_obj_grad)
if output_obj_grad is None:
engine.backward(output_obj + moe_loss)
else:
# scale the moe loss by the loss scale
moe_loss = moe_loss * engine.optimizer.loss_scale
engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
# Collect the grad of the input_obj.
input_obj_grad = None
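The branch above that passes both the stage output and the MoE loss to engine.backward_by_grad is, for an intermediate pipeline stage, conceptually the autograd call sketched below (a simplification with assumed names; the real engine also handles AMP and loss scaling).

import torch

def backward_with_moe(output_obj, output_obj_grad, moe_loss):
    # output_obj receives the gradient sent back from the next pipeline stage,
    # while the locally computed scalar moe_loss contributes its own gradient (None -> 1.0)
    torch.autograd.backward(tensors=[output_obj, moe_loss],
                            grad_tensors=[output_obj_grad, None])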
@ -329,7 +360,7 @@ class PipelineScheduler(BaseScheduler):
return input_obj_grad
def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
"""
This function performs forward only computation process. The scheduling of microbatches is similar to the
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
@ -356,6 +387,7 @@ class PipelineScheduler(BaseScheduler):
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
accum_moe_loss = torch.zeros(1, device=get_current_device())
# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
@ -376,12 +408,14 @@ class PipelineScheduler(BaseScheduler):
input_obj = None
# Perform forward computation
output_obj = self._forward_step(
output_obj, _ = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
@ -392,10 +426,14 @@ class PipelineScheduler(BaseScheduler):
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
return output, label, accum_loss
if accum_loss is not None:
accum_loss += accum_moe_loss
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
return output, label, accum_loss, accum_moe_loss
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
"""
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
It consists of three stages: warmup, 1F1B, and cooldown.
@ -441,12 +479,14 @@ class PipelineScheduler(BaseScheduler):
# Input, output tensors only need to be saved when doing backward passes
input_objs = []
output_objs = []
moe_losses = []
return_tensors = []
accum_loss = (
torch.zeros(1, device=get_current_device())
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
accum_moe_loss = torch.zeros(1, device=get_current_device())
# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
@ -468,12 +508,14 @@ class PipelineScheduler(BaseScheduler):
input_obj = None
# Perform forward computation
output_obj = self._forward_step(
output_obj, moe_loss = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
@ -493,6 +535,7 @@ class PipelineScheduler(BaseScheduler):
input_objs.append(input_obj)
output_objs.append(output_obj)
moe_losses.append(moe_loss)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@ -512,12 +555,14 @@ class PipelineScheduler(BaseScheduler):
# Run 1F1B in steady state.
for i in range(num_1f1b_micropairs):
# Perform forward computation
output_obj = self._forward_step(
output_obj, moe_loss = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)
if gpc.is_last_rank(ParallelMode.PIPELINE):
@ -533,13 +578,15 @@ class PipelineScheduler(BaseScheduler):
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
moe_losses.append(moe_loss)
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
moe_loss = moe_losses.pop(0)
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
if i == (num_1f1b_micropairs - 1):
input_obj = None
@ -563,6 +610,7 @@ class PipelineScheduler(BaseScheduler):
for i in range(num_warmup_microsteps):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
moe_loss = moe_losses.pop(0)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
output_obj_grad = comm.recv_backward(
@ -574,17 +622,25 @@ class PipelineScheduler(BaseScheduler):
output_obj_grad = None
input_obj_grad = self._backward_step(
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss
)
if not gpc.is_first_rank(ParallelMode.PIPELINE):
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
return output, label, accum_loss
if accum_loss is not None:
accum_loss += accum_moe_loss
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
return output, label, accum_loss, accum_moe_loss
def forward_backward_step(
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@ -596,7 +652,7 @@ class PipelineScheduler(BaseScheduler):
return_loss (bool, optional): Whether returns the loss value. Default is true.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
"""
assert (
@ -607,9 +663,9 @@ class PipelineScheduler(BaseScheduler):
self.load_batch(engine, data_iter)
if forward_only:
return self._forward_only_step(engine, return_loss, return_output_label)
return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
else:
return self._forward_backward_step(engine, return_loss, return_output_label)
return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)
class InterleavedPipelineScheduler(PipelineScheduler):
@ -676,21 +732,35 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
self._accum_loss = None
self._accum_moe_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(num_chunks)]
self._output_objs = [[] for _ in range(num_chunks)]
self._output_obj_grads = [[] for _ in range(num_chunks)]
self._moe_losses = [[] for _ in range(num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
self._output_obj_shapes = [None for _ in range(num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape
@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
def _clear_state(self) -> None:
self._accum_loss = None
self._accum_moe_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(self._num_chunks)]
self._output_objs = [[] for _ in range(self._num_chunks)]
self._output_obj_grads = [[] for _ in range(self._num_chunks)]
self._moe_losses = [[] for _ in range(self._num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
self._output_obj_shapes = [None for _ in range(self._num_chunks)]
@ -712,7 +782,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return move_to_device(micro_batch_data)
def _forward_step(self, engine, chunk_id):
def _forward_step(self, engine, chunk_id, moe_loss_coeff=1.0):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -734,7 +804,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model[chunk_id], data)
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
@ -754,7 +824,14 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
if self._accum_moe_loss is not None:
self._accum_moe_loss.add_(moe_loss.detach())
self._output_objs[chunk_id].append(output_obj)
self._moe_losses[chunk_id].append(moe_loss)
return output_obj
@ -780,8 +857,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
input_obj = self._input_objs[chunk_id].pop(0)
output_obj = self._output_objs[chunk_id].pop(0)
output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
moe_loss = self._moe_losses[chunk_id].pop(0)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss)
return input_obj_grad
@ -813,6 +891,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
receive_extra_backward: bool = False,
forward_only: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the warm-up loop and prepare data for the 1F1B stage.
@ -850,12 +929,13 @@ class InterleavedPipelineScheduler(PipelineScheduler):
for k in range(num_warmup_microsteps):
chunk_id = self._get_chunk_by_microbatch(k)
output_obj = self._forward_step(engine, chunk_id)
output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff)
if forward_only:
# when forward-only, no need to save tensors for a backward pass
self._input_objs[chunk_id].pop()
self._output_objs[chunk_id].pop()
self._moe_losses[chunk_id].pop()
if not gpc.is_pipeline_last_stage():
if isinstance(output_obj, torch.Tensor):
@ -931,6 +1011,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the 1F1B loop with overlap.
@ -960,7 +1041,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
# 1. Forward pass.
output_obj = self._forward_step(engine, forward_chunk_id)
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
# 2. Check if the backward input is ready.
if backward_async_communicator is not None:
@ -1045,6 +1126,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the 1F1B loop without overlap.
@ -1066,7 +1148,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# Forward pass.
forward_microstep_id = k + num_warmup_microsteps
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
output_obj = self._forward_step(engine, forward_chunk_id)
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)
# Backward pass.
backward_microstep_id = k
@ -1171,7 +1253,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
)
)
def _forward_only_step(self, engine: Engine):
def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
num_microsteps = self.num_microbatches * self._num_chunks
num_warmup_microsteps = num_microsteps
@ -1181,9 +1263,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps,
receive_extra_backward=False,
forward_only=True,
moe_loss_coeff=moe_loss_coeff,
)
def _forward_backward_step(self, engine: Engine):
def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
# Compute number of warmup and remaining microbatches.
all_warmup_microsteps = False
num_microsteps = self.num_microbatches * self._num_chunks
@ -1217,6 +1300,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_microsteps,
num_warmup_steps,
receive_extra_backward=receive_extra_backward,
moe_loss_coeff=moe_loss_coeff,
)
# 2. 1F1B
@ -1225,12 +1309,15 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_steps,
num_1f1b_micropairs=num_1f1b_micropairs,
all_warmup_microsteps=all_warmup_microsteps,
moe_loss_coeff=moe_loss_coeff,
)
# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
def forward_backward_step(
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
@ -1250,24 +1337,36 @@ class InterleavedPipelineScheduler(PipelineScheduler):
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
gpc.set_virtual_pipeline_parallel_rank(0)
self.load_batch(engine, data_iter)
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
self._accum_loss = torch.zeros(1, device=get_current_device())
self._accum_moe_loss = torch.zeros(1, device=get_current_device())
if return_output_label:
self._return_tensors = []
if forward_only:
self._forward_only_step(engine)
self._forward_only_step(engine, moe_loss_coeff)
else:
self._forward_backward_step(engine)
self._forward_backward_step(engine, moe_loss_coeff)
if return_output_label and len(self._return_tensors) > 0:
output, label = pack_return_tensors(self._return_tensors)
else:
output, label = (None, None)
logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
accum_moe_loss = self._accum_moe_loss
accum_loss = self._accum_loss
if accum_loss is not None:
accum_loss += self._accum_moe_loss
self._clear_state()
return output, label, accum_loss
return output, label, accum_loss, accum_moe_loss
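To summarize the MoE-loss bookkeeping added throughout this scheduler, the sketch below (illustrative only, assumed names) shows the per-micro-batch accumulation and the final all-reduce across the pipeline group, so that every stage reports the same auxiliary loss value.

import torch
import torch.distributed as dist

def accumulate_moe_loss(moe_losses_per_layer, moe_loss_coeff, num_microbatches, accum_moe_loss):
    moe_loss = sum(moe_losses_per_layer) * moe_loss_coeff / num_microbatches
    accum_moe_loss.add_(moe_loss.detach())  # detached copy only used for reporting
    return moe_loss

def finalize_moe_loss(accum_moe_loss, accum_loss, pipeline_group):
    # each pipeline stage only saw the MoE layers it hosts, so sum across stages
    dist.all_reduce(accum_moe_loss, group=pipeline_group)
    if accum_loss is not None:
        accum_loss += accum_moe_loss
    return accum_loss, accum_moe_loss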

View File

@ -155,5 +155,5 @@ class Trainer:
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
"""
output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
return output, label, loss
output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
return output, label, loss, moe_loss

View File

@ -108,67 +108,96 @@ def args_sanity_check():
logger.info(f"valid_every: {data.valid_every}")
# processing the checkpoint config
if "enable_save_ckpt" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("enable_save_ckpt", False)
ckpt = gpc.config.ckpt
if "enable_save_ckpt" not in ckpt:
ckpt._add_item("enable_save_ckpt", False)
if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
gpc.config.ckpt._add_item("checkpoint_every", float("inf"))
# Saving checkpoint args.
if ckpt.enable_save_ckpt:
assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
assert ckpt.checkpoint_every > 0
assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
if "load_optimizer" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_optimizer", True)
if "async_upload" not in ckpt:
ckpt._add_item("async_upload", False) # async defalut is False.
else:
if ckpt.async_upload:
assert "save_ckpt_folder" in ckpt
if "boto3:" not in ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning(
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
)
ckpt.async_upload = False
else:
if "async_upload_tmp_folder" not in ckpt:
ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "save_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("save_ckpt_folder", None)
if not ckpt.async_upload:
ckpt._add_item("async_upload_tmp_folder", None)
if "load_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_ckpt_folder", None)
if "snapshot_ckpt_folder" not in ckpt:
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
if "load_model_only_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_model_only_folder", None)
if "oss_snapshot_freq" not in ckpt:
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
else:
ckpt._add_item("checkpoint_every", float("inf"))
ckpt._add_item("oss_snapshot_freq", float("inf"))
ckpt._add_item("save_ckpt_folder", None)
ckpt._add_item("async_upload", False)
ckpt._add_item("async_upload_tmp_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)
if "async_upload" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload", False)
# Loading checkpoint args.
if "load_model_only_folder" not in ckpt:
ckpt._add_item("load_model_only_folder", None)
if "async_upload_tmp_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "load_ckpt_folder" not in ckpt:
ckpt._add_item("load_ckpt_folder", None)
if gpc.config.ckpt.async_upload:
assert "save_ckpt_folder" in gpc.config.ckpt
if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning("Storing ckpt on file system does not support asynchronous storage, will use sync save!")
gpc.config.ckpt.async_upload = False
if "load_optimizer" not in ckpt:
ckpt._add_item("load_optimizer", True)
if "snapshot_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
if "stop_file_path" not in ckpt:
ckpt._add_item("stop_file_path", None)
if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
assert gpc.config.ckpt.oss_snapshot_freq > 0
if "load_given_ckpt" not in ckpt:
# If 'load_given_ckpt' is not given, we set it to False so that internlm has the opportunity
# to auto-load the latest checkpoint.
ckpt._add_item("load_given_ckpt", False)
assert not (
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
if ckpt.load_given_ckpt:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
logger.warning(
"Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
)
ckpt.load_model_only_folder = None
if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
logger.info(f"async_upload: {gpc.config.ckpt.async_upload}")
if gpc.config.ckpt.async_upload:
logger.info(f"async_upload_tmp_folder: {gpc.config.ckpt.async_upload_tmp_folder}")
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
# initialization storage manager
init_storage_manager(gpc.config.ckpt)
init_storage_manager(ckpt)
# tensorboard writer config
if "enable_tb" not in gpc.config:
gpc.config._add_item("enable_tb", True)
if "tensorboard_folder" not in gpc.config:
gpc.config._add_item("tensorboard_folder", None)
gpc.config._add_item(
"tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
)
if "resume_tb_folder" not in gpc.config:
gpc.config._add_item("resume_tb_folder", None)
gpc.config._add_item(
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
)
# cudnn
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
@ -236,11 +265,13 @@ def args_sanity_check():
# process the model config
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
if "sequence_parallel" not in gpc.config.model:
gpc.config.model._add_item("sequence_parallel", False)
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)
else:
assert not (
gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"
# feishu webhook address for alerting
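Stripped of the gpc plumbing, the checkpoint-argument defaulting above boils down to something like the following sketch over a plain dict (the helper name and dict layout are assumptions, shown only to clarify the precedence of the options).

def apply_ckpt_defaults(ckpt: dict) -> dict:
    ckpt.setdefault("enable_save_ckpt", False)
    if ckpt["enable_save_ckpt"]:
        assert ckpt.get("checkpoint_every", 0) > 0 and "save_ckpt_folder" in ckpt
        ckpt.setdefault("async_upload", False)
        if ckpt["async_upload"] and "boto3:" not in ckpt["save_ckpt_folder"]:
            ckpt["async_upload"] = False  # file systems only support synchronous saving
        ckpt.setdefault("oss_snapshot_freq", float("inf"))  # inf disables snapshots
    else:
        ckpt.update(checkpoint_every=float("inf"), oss_snapshot_freq=float("inf"),
                    save_ckpt_folder=None, async_upload=False)
    ckpt.setdefault("load_given_ckpt", False)  # False lets internlm auto-load the latest ckpt
    return ckpt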

View File

@ -7,6 +7,7 @@ import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn
@ -56,7 +57,7 @@ class Embedding1D(nn.Module):
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
if gpc.config.model.sequence_parallel:
if gpc.config.parallel.sequence_parallel:
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
return output
@ -111,6 +112,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
class RotaryEmbedding(torch.nn.Module):

View File

@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear):
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
)
@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
)
@ -173,7 +173,7 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
@ -182,7 +182,7 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)
@ -191,7 +191,7 @@ class FeedForward(nn.Module):
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

View File

@ -393,7 +393,7 @@ class PackedFlashInternLm1D(nn.Module):
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

View File

@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
drop_tokens: bool = True,
use_rts: bool = True,
using_default_moe: bool = True,
use_residual=True,
use_residual=False,
residual_mlp=None,
):

View File

@ -82,7 +82,7 @@ class MHA(nn.Module):
3 * embed_dim,
process_group,
bias=True,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
) # according to https://spaces.ac.cn/archives/9577
@ -95,7 +95,11 @@ class MHA(nn.Module):
# output projection always have the bias (for now)
self.out_proj = RowParallelLinearTorch(
embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
embed_dim,
embed_dim,
process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign tp attribute so that internlm know it is tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:

View File

@ -356,6 +356,8 @@ class TopKGate(Module):
# Only top-1 and top-2 are supported at the moment.
if k not in (1, 2):
raise ValueError("Only top-1 and top-2 gatings are supported.")
# TODO: can we use tensor parallel here?
# Following DeepSpeed's mechanism, always use fp32
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.k = k
self.capacity_factor = capacity_factor
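For context, a minimal top-k gate matching the constraint above might look like the sketch below (illustrative only; the actual TopKGate also computes expert capacity, token dropping, and the auxiliary balancing loss).

import torch
import torch.nn.functional as F

def topk_gate(tokens: torch.Tensor, wg: torch.nn.Linear, k: int = 1):
    assert k in (1, 2), "Only top-1 and top-2 gatings are supported."
    logits = wg(tokens.float())            # the gate always runs in fp32
    probs = F.softmax(logits, dim=-1)      # (num_tokens, num_experts)
    gate_vals, expert_idx = probs.topk(k, dim=-1)
    return gate_vals, expert_idx

wg = torch.nn.Linear(16, 4, bias=False).float()
vals, idx = topk_gate(torch.randn(8, 16), wg, k=2)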

View File

@ -3,6 +3,7 @@
import math
from functools import partial
from itertools import product
import torch
import torch.distributed as dist
@ -20,6 +21,7 @@ from internlm.solver.optimizer.store import (
)
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
ParamBcastSyncHandler,
flatten,
get_grad_accumulate_object,
has_inf_or_nan,
@ -88,10 +90,10 @@ class HybridZeroOptimizer(BaseOptimizer):
self,
optimizer: Optimizer,
cpu_offload=False,
overlap_broadcast=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
has_moe: bool = False,
param_bcast_sync_handler: ParamBcastSyncHandler = None,
):
# DynamicGradScaler related args
if gpc.config.model.dtype is torch.float32:
@ -163,7 +165,9 @@ class HybridZeroOptimizer(BaseOptimizer):
+ f"zo-{self._zero_local_rank}.pt"
)
self.params_per_rank_id_dict = []
self.overlap_broadcast = overlap_broadcast
self._param_bcast_sync_handler = param_bcast_sync_handler
if self._overlap_communication:
assert self._param_bcast_sync_handler is not None
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@ -238,6 +242,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
@ -284,8 +290,10 @@ class HybridZeroOptimizer(BaseOptimizer):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])
rank_to_go = numel_per_rank.index(min(numel_per_rank))
if self._overlap_communication:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()
@ -322,7 +330,9 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduction_func = partial(
self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
self._store_and_try_reduce_grads_by_bucket,
param=param,
reduce_rank=reduce_rank,
)
# define hook
@ -416,16 +426,16 @@ class HybridZeroOptimizer(BaseOptimizer):
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
if self._overlap_communication:
stream = self._comm_stream
stream.synchronize()
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
with torch.cuda.stream(self._comm_stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=dp_parallel_mode
tensor=flat,
dtype=self.dtype,
dst_rank=reduce_rank,
parallel_mode=dp_parallel_mode,
)
# update the reduced tensor
@ -616,7 +626,10 @@ class HybridZeroOptimizer(BaseOptimizer):
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None
@ -678,37 +691,42 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
# TODO: support broadcast overlap
self.broadcast_params(overlap=False)
with torch.cuda.stream(self._comm_stream):
self.broadcast_params()
timer("step").stop()
# updating gradients may not be needed here, because the sync_params function is used during initialization,
# so synchronization is maintained
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
def broadcast_params(self, overlap=False):
def broadcast_params(self):
handles = []
for group_id in range(self.num_param_groups):
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
if self._is_moe_group(self.optim.param_groups[group_id]):
continue
for rank in range(self._zero_world_size):
# The following operations are performed only on the rank to which parameters are assigned.
if rank not in self.param_group_no_params_ranks[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
)
handles.append(handle)
# The following operations are performed only on the rank to which parameters are assigned.
if rank in self.param_group_no_params_ranks[group_id]:
continue
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(ParallelMode.ZERO1),
async_op=True,
)
if not overlap:
for handle in handles:
handle.wait()
else:
return handles
if self._overlap_communication:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
for handle in handles:
handle.wait()
##################
# FP16 Utilities #
@ -726,7 +744,11 @@ class HybridZeroOptimizer(BaseOptimizer):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
dist.all_reduce(
self._found_overflow,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.GLOBAL),
)
return self._found_overflow.item() > 0
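Conceptually, the broadcast_params refactor above implements the usual ZeRO-1 step-end synchronization: each zero rank updates only its own flat fp16 shard and then broadcasts it to the rest of the group. A simplified sketch with assumed names (the real code also skips MoE groups and ranks that own no parameters, and can hand the handles to ParamBcastSyncHandler instead of waiting):

import torch.distributed as dist

def broadcast_shards(flat_shards, zero_group, group_ranks):
    handles = []
    for owner_rank, shard in enumerate(flat_shards):
        global_src = group_ranks[owner_rank]          # local zero rank -> global rank
        handles.append(dist.broadcast(shard, src=global_src,
                                      group=zero_group, async_op=True))
    for handle in handles:
        handle.wait()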

View File

@ -3,15 +3,18 @@
import math
from abc import ABC, abstractmethod
from typing import Dict, Optional
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from torch import Tensor, nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter
@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor):
def split_half_float_double(tensor_list):
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for _, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
if bucket:
buckets.append(bucket)
dtype_buckets = {
"torch.cuda.HalfTensor": [],
"torch.cuda.FloatTensor": [],
"torch.cuda.DoubleTensor": [],
"torch.cuda.BFloat16Tensor": [],
}
for t in tensor_list:
dtype = t.type()
if dtype in dtype_buckets:
dtype_buckets[dtype].append(t)
buckets = [bucket for bucket in dtype_buckets.values() if bucket]
return buckets
@ -184,7 +194,10 @@ def calc_l2_norm(grads):
if APEX_AVAILABLE:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False, # no per-parameter norm
)
else:
norm, _ = multi_tensor_l2norm_torch(grads, False)
@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
# Take max across all model-parallel GPUs.
if gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
dist.all_reduce(
total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL):
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
dist.all_reduce(
total_norm,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.MODEL),
)
# This is because we use zero1, so we need to use this reduction.
# TODO: Check zero group to be a subset of dp group.
@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
self._scale = self._scale.fill_(state_dict["_scale"])
self._growth_step = state_dict["_growth_step"]
self._hysteresis_step = state_dict["_hysteresis_step"]
class ParamBcastSyncHandler:
"""
Model partition handler for overlapping parameter broadcast with forward computation
"""
def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
self._block_to_param = OrderedDict() # <key: nn.Module> <value: list(param)>
self._param_to_rank = dict() # <key: param> <value: rank>
self._block_to_rank = dict() # <key: nn.Module> <value: rank>
self._bcast_handles = dict() # <key: rank> <value: list(bcast handles)>
zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
total_param_num = sum(p.numel() for p in model.parameters())
avg_param_num = total_param_num * 1.0 // zero1_size
# reuse the same for loop for both ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
# record which transformer/embedding/head/norm block each parameter belongs to
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _, children in _chunk.named_children():
# should be the transformer block definition in modeling_xxx.py
if isinstance(children, nn.ModuleList):
# record the block that a parameter belongs to
for _, block in enumerate(children):
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
self._block_to_param[block] = list(block.parameters())
else:
# record the block that a parameter belongs to
# self._block_to_param[name] = list(children.parameters())
self._block_to_param[children] = list(children.parameters())
alloc_num = 0
rank_to_go = 0
# process the parameters in block_to_param sequentially,
# allocate each parameter to a local rank of ParallelMode.ZERO1,
# NOTE that we do NOT consider the following scenarios:
# 1) whether a parameter is trainable;
# 2) parameters may be in different optimizer groups
for block, params in self._block_to_param.items():
# allocate a model block to a local rank of ParallelMode.ZERO1
self._block_to_rank[block] = [rank_to_go]
for p in params:
alloc_num = alloc_num + p.numel()
# in this case, allocate the param to next rank if possible
if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
rank_to_go = rank_to_go + 1
alloc_num = 0
self._block_to_rank[block].append(rank_to_go)
# allocate a parameter to a local rank of ParallelMode.ZERO1
self._param_to_rank[p] = rank_to_go
# initialize an empty list for _bcast_handles of each rank
for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
self._bcast_handles[rank] = []
# register_forward_pre_hook for transformer/embedding/norm/xxx block
self._register_sync_parameters_hook()
def _register_sync_parameters_hook(self) -> None:
def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613
bcast_handles = []
# gather all required broadcast handles into a list
for rank in self._block_to_rank[model]:
bcast_handles.extend(self._bcast_handles[rank])
# clear _bcast_handles since the gathered handles are waited on below
self._bcast_handles[rank] = []
# wait for all required broadcast handles to complete
for handle in bcast_handles:
handle.wait()
# register_forward_pre_hook for transformer/embedding/norm/xxx block
for block, _ in self._block_to_rank.items():
block.register_forward_pre_hook(partial(_pre_forward_hook))
def get_rank_by_param(self, param) -> int:
return self._param_to_rank[param]
def add_bcast_handle(self, rank, handle) -> None:
self._bcast_handles[rank].append(handle)
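The hook mechanism above can be boiled down to the self-contained sketch below (a simplification with assumed names): broadcast handles are queued per block, and a forward pre-hook drains the queue just before that block runs, which is what lets the parameter broadcast overlap with earlier blocks' forward computation.

import torch
from torch import nn

class TinyBcastSync:
    def __init__(self, blocks):
        self._handles = {id(block): [] for block in blocks}
        for block in blocks:
            block.register_forward_pre_hook(self._wait_hook)

    def add_handle(self, block, handle):
        self._handles[id(block)].append(handle)

    def _wait_hook(self, block, inputs):
        # make sure this block's freshly broadcast parameters have arrived
        for handle in self._handles[id(block)]:
            handle.wait()
        self._handles[id(block)].clear()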

View File

@ -0,0 +1,21 @@
from .training_internlm import (
get_train_data_loader,
get_validation_data_loader,
initialize_distributed_env,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
)
__all__ = [
"get_train_data_loader",
"get_validation_data_loader",
"initialize_distributed_env",
"initialize_llm_profile",
"initialize_model",
"initialize_optimizer",
"load_new_batch",
"record_current_batch_training_metrics",
]
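Assuming this is the package __init__ for the new training helpers, a training script would typically wire them together roughly as follows (a hedged sketch; the import path and argument values are assumptions, and the real entry script also builds the trainer, schedulers, and training loop).

from internlm.train import (
    get_train_data_loader,
    initialize_distributed_env,
    initialize_model,
    initialize_optimizer,
)

initialize_distributed_env(config="./configs/moe_cfg.py", launcher="slurm")
model = initialize_model()
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model)
train_dl, dataset_types = get_train_data_loader(num_worker=4)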

View File

@ -0,0 +1,447 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
from functools import partial
from typing import Callable, Iterable, Union
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
from internlm.data.dataset import get_dataset_dict
from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import (
PackedDataset,
PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.moe import create_moe_param_groups, has_moe_layers
from internlm.monitor import set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.utils.common import DummyProfile, get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
is_no_pp_or_last_stage,
sync_model_param_with_ep,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
logger = get_logger(__file__)
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
"""
Initialize distributed environment for distributed training.
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
master_port (int): The master port for distributed training. 8888 by default.
seed (int, optional): Specified random seed for every process. 1024 by default.
"""
torch.cuda.empty_cache()
if launcher == "torch":
internlm.launch_from_torch(config=config, seed=seed)
elif launcher == "slurm":
internlm.launch_from_slurm(
config=config,
host=get_master_node(),
port=master_port,
seed=seed,
)
else:
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
def initialize_model():
"""
Initialize model.
Returns: The neural network model to be trained or evaluated.
"""
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
if isinstance(model, nn.ModuleList):
model = nn.ModuleList(
[
NaiveAMPModel(
model=_m,
output_to_fp32=False, # manually controlled by interleaved pipeline scheduler
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
for _m in model
]
)
else:
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
# This sync is very important, because the model weights kept in the optimizer are copied
# from the original parameters in memory, so we must make sure the dp sync
# does not cause the weights kept in the optimizer to differ from the original parameters.
sync_model_param_with_ep(model)
# This function is needed to make sure parameters that are not split by tensor parallelism are
# the same across tensor-parallel ranks.
sync_model_param_within_tp(model)
return model
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize optimizer.
Args:
model (torch.nn.Module): Your model instance to be trained or evaluated.
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
param_bcast_sync_handler = ParamBcastSyncHandler(model)
adam_cfg = gpc.config.adam
if gpc.config.model.num_experts > 1:
params = create_moe_param_groups(model, adam_cfg.weight_decay)
else:
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
naive_optimizer = torch.optim.AdamW(
params=params,
lr=adam_cfg.lr,
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
eps=adam_cfg.adam_eps,
)
has_moe = has_moe_layers(model)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
has_moe=has_moe,
param_bcast_sync_handler=param_bcast_sync_handler,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
return optimizer, beta2_scheduler, lr_scheduler
def get_train_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
"""
Generate and return the training data loader.
Returns: A tuple of (train_dl, dataset_types).
"""
# Get the dataset types
dataset_types = None
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
data_cfg = gpc.config.data
# Get the sample weight dictionary
train_folder = data_cfg.train_folder
if not train_folder:
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
if data_cfg.pack_sample_into_one:
train_ds = PackedDatasetWithoutCuSeqlen(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = PackedDataset(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
if dataset_generate_func is not None:
train_ds = dataset_generate_func()
else:
train_ds = get_packed_dataset_without_short_length(
folder=data_cfg.train_folder,
packed_length=data_cfg.packed_length,
max_length_per_sample=data_cfg.seq_len,
show_progress=dist.get_rank() == 0,
min_length=data_cfg.min_length,
min_length_dict=data_cfg.get("min_length_dict", {}),
pack_into_one_sample=data_cfg.pack_sample_into_one,
)
if dataset_generate_func is None or not train_folder:
# partition already completed
assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
# Create the training dataset sampler
train_sampler = StaticBatchSampler(
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
batch_size=data_cfg.micro_num,
rampup_batch_size=data_cfg.rampup_batch_size,
micro_bsz=data_cfg.micro_bsz,
seed=1024,
drop_last=True,
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)
if dataset_generate_func is None or not train_folder:
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
# Create the training data loader
train_dl = DataLoader(
dataset=train_ds,
batch_sampler=train_sampler,
num_workers=num_worker,
pin_memory=True,
collate_fn=train_collate_fn,
persistent_workers=num_worker > 0,
)
return train_dl, dataset_types
def get_validation_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
):
"""Generate and return the validation data loader."""
data_cfg = gpc.config.data
if not data_cfg.valid_folder:
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
else:
if dataset_generate_func is not None:
assert val_collate_fn and dataloader_func is not None
val_ds = dataset_generate_func()
else:
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
if not isinstance(val_ds, dict):
val_ds = {"val": val_ds}
if val_collate_fn is None or not data_cfg.valid_folder:
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
val_dls = {}
for val_name, ds in val_ds.items():
if dataloader_func and data_cfg.valid_folder is not None:
val_dls[val_name] = dataloader_func(dataset=ds, collate_fn=val_collate_fn)
if gpc.is_rank_for_log():
logger.info(
f"load validation dataset {val_name} with valid batch size {str(data_cfg.valid_micro_num)} and "
f"{ds.size} Byte samples."
)
else:
# making the validation batch size larger can speed up evaluation, but it should not be too large,
# otherwise too much data may be dropped
batch_size = min(
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
)
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
if batch_size == 0 and gpc.is_rank_for_log():
logger.info(f"skip validate {val_name}.")
continue
val_dls[val_name] = get_dpsampler_dataloader(
ds,
shuffle=False,
num_workers=num_worker,
batch_size=batch_size,
collate_fn=val_collate_fn,
drop_last=True,
) # drop_last=True, otherwise it may cause problems in the last batch
if gpc.is_rank_for_log():
logger.info(
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
f"samples {str(len(val_dls[val_name]))}."
)
return val_dls
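# Worked example for the validation batch-size rounding above (hypothetical numbers): with
# valid_micro_num=4, micro_bsz=2, 70 samples and a data-parallel world size of 8,
# batch_size = min(4 * 2, 70 // 8) = 8, and 8 // 2 * 2 = 8 keeps it a multiple of micro_bsz;
# with only 10 samples, min(8, 10 // 8) = 1 rounds down to 0 and that dataset is skipped.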
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return the new batch data based on training data loader.
Args:
train_dl (torch.utils.data.DataLoader): Dataloader for training.
train_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
train_state (TrainState): Current training state.
Returns: A batch of data and the updated train_iter.
"""
timer("batch-gen").start()
try:
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
if hasattr(train_state, "batch_sampler_iter"):
next(train_state.batch_sampler_iter)
except StopIteration:
train_iter = iter(train_dl)
batch = next(train_iter)
train_state.num_consumed_samples_in_epoch = 0
if hasattr(train_state, "batch_sampler"):
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
next(train_state.batch_sampler_iter)
timer("batch-gen").stop()
if batch[0].get("type_ids", None) is not None:
# if use_flash_attn is False, we need to unpack type_ids
if not gpc.config.model.use_flash_attn:
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
return batch, train_iter
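# A minimal sketch of how load_new_batch is typically driven from a training loop (assumptions:
# train_dl and train_state come from the surrounding training setup, and total_steps is an
# arbitrary example bound):
def _example_training_loop(train_dl, train_state, total_steps=10):
    train_iter = iter(train_dl)
    for _ in range(total_steps):
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
        # ... the forward/backward pass on `batch` would go here ...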
def initialize_llm_profile(profiling: bool = False, start_time: str = None):
"""Initialize and return the profiler context manager instance."""
if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
llm_profile = torch.profiler.profile
logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
else:
llm_profile = DummyProfile
return llm_profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
+ f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
),
with_stack=True,
with_modules=True,
)
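# A minimal usage sketch for the profiler helper above (assumptions: start_time is any timestamp
# string, and run_one_step is a hypothetical stand-in for one forward/backward iteration):
def _example_profiled_training(start_time, run_one_step, num_steps=8):
    with initialize_llm_profile(profiling=True, start_time=start_time) as prof:
        for _ in range(num_steps):
            run_one_step()
            prof.step()  # advances the torch.profiler schedule, or is a no-op for DummyProfile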
def record_current_batch_training_metrics(
get_tflops_func,
logger,
writer,
success_update,
batch_count,
batch,
train_state,
optimizer,
beta2_scheduler,
trainer,
start_time,
loss,
moe_loss,
grad_norm,
metric,
update_panel,
):
"""
Print some training metrics of current batch.
"""
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
acc_perplex = metric.get_metric()
if success_update and gpc.is_rank_for_log():
lr = optimizer.param_groups[0]["lr"]
if hasattr(trainer.engine.optimizer, "grad_scaler"):
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
num_tokens_in_batch = batch[1].nelement()
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
tk_per_gpu = 0
tk_per_gpu = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
/ (time.time() - start_time),
2,
)
tflops = get_tflops_func((time.time() - start_time))
infos = {
"tflops": tflops,
"step": batch_count,
"loss": loss.item(),
"moe_loss": moe_loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu,
"lr": lr,
"loss_scale": scaler,
"grad_norm": grad_norm,
}
infos["micro_num"] = len(batch[1])
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
infos["largest_length"] = max_length_in_batch # the longest input
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
infos["smallest_batch"] = min_samples_in_batch
infos["adam_beta2"] = beta2_scheduler.get_beta2()
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
infos["fwd_bwd_time"] = fwd_bwd_time
for key, value in acc_perplex.items():
infos[key] = value
line = ""
for key, value in infos.items():
line += f"{key}={value} "
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if update_panel:
logger.info(
line,
extra={
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"grad_norm": grad_norm,
"loss": loss.item(),
"moe_loss": moe_loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
},
)
else:
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
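# Worked example for the tgs (tokens/gpu/second) figure above (hypothetical numbers): with
# num_tokens_in_batch=65536, a data-parallel world size of 8, a global world size of 32 and a
# 4.0 s step, tk_per_gpu = 65536 * 8 / 32 / 4.0 = 4096.0 tokens per GPU per second.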

@ -218,3 +218,21 @@ def get_megatron_flops(
tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
return tflops
class DummyProfile:
"""
Dummy Profile.
"""
def __init__(self, *args, **kwargs) -> None:
pass
def __enter__(self):
return self
def __exit__(self, a, b, c):
pass
def step(self):
pass

@ -50,6 +50,16 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
trainer.schedule._hooks = prev_metric_hooks
@contextmanager
def switch_sequence_parallel_mode():
prev_mode = gpc.config.parallel.sequence_parallel
try:
gpc.config.parallel.sequence_parallel = False
yield
finally:
gpc.config.parallel.sequence_parallel = prev_mode
def evaluate_on_val_dls(
trainer,
val_dls,
@ -57,110 +67,102 @@ def evaluate_on_val_dls(
logger,
step_count,
update_panel: bool = False,
streaming: bool = False,
):
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
with switch_sequence_parallel_mode():
torch.cuda.empty_cache()
trainer.eval()
verbose = gpc.is_rank_for_log()
data_cfg = gpc.config.data
for val_name, val_dl in val_dls.items():
if len(val_dl) == 0 and verbose:
logger.info(f"Validation dataset: {val_name} is empty")
continue
for val_name, val_dl in val_dls.items():
if len(val_dl) == 0 and verbose and not streaming:
logger.info(f"Validation dataset: {val_name} is empty")
continue
val_metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
)
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
val_metric = AccPerplex(
device=torch.cuda.current_device(),
tp_pg=gpc.get_group(ParallelMode.TENSOR),
dp_pg=gpc.get_group(ParallelMode.DATA),
)
val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
val_loss = 0
val_idx = -1
for val_idx, batch in tqdm(
enumerate(val_dl),
desc="Val.",
total=len(val_dl),
position=1,
disable=not verbose,
leave=False,
):
with torch.inference_mode():
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
val_loss = 0
val_idx = -1
for val_idx, batch in tqdm(
enumerate(val_dl),
desc="Val.",
total=len(val_dl) if not streaming else None,
position=1,
disable=not verbose,
leave=False,
):
with torch.inference_mode():
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
num_microbatches = total_val_bsz // data_cfg.micro_bsz
tensor_shape = torch.Size(
[data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
num_microbatches=num_microbatches,
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss, _ = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss, _ = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
dist.barrier()
val_res = val_metric.get_metric()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
infos = {
"step": step_count,
f"val/{val_name}_loss": val_loss,
f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"],
}
for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count)
if update_panel:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
extra={
"step": step_count,
"val_loss": val_loss,
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
},
)
else:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
)
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
num_microbatches=num_microbatches,
tensor_shape=tensor_shape,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
assert total_val_bsz % data_cfg.micro_bsz == 0
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
grad_accum_batch_size = data_cfg.micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
grad_accum_batch_size=grad_accum_batch_size,
metric_hook_list=[val_sche_metric_hook],
):
_, _, loss = trainer.execute_schedule(
batch, forward_only=True, return_loss=True, return_output_label=False
)
if verbose:
val_loss += loss.item()
assert val_idx != -1
trainer.train()
torch.cuda.empty_cache()
dist.barrier()
val_res = val_metric.get_metric()
if verbose and len(val_dl) != 0:
val_loss = val_loss / (val_idx + 1 + 1e-6)
infos = {
"step": step_count,
f"val/{val_name}_loss": val_loss,
f"val/{val_name}_acc": val_res["acc"],
f"val/{val_name}_plex": val_res["perplexity"],
}
for key, value in infos.items():
writer.add_scalar(key=key, value=value, step=step_count)
if update_panel:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
extra={
"step": step_count,
"val_loss": val_loss,
"val_acc": val_res["acc"],
"val_perplexity": val_res["perplexity"],
},
)
else:
logger.info(
f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
)
trainer.train()
torch.cuda.empty_cache()
dist.barrier()
@contextmanager
def switch_sequence_parallel_mode():
prev_mode = gpc.config.model.sequence_parallel
try:
gpc.config.model.sequence_parallel = False
yield
finally:
gpc.config.model.sequence_parallel = prev_mode

@ -14,18 +14,19 @@ class _Timer:
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
self.stream = torch.cuda.current_stream()
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.stream.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
self.stream.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False

@ -2,8 +2,12 @@
# -*- encoding: utf-8 -*-
import copy
import fcntl
import os
import re
import socket
import time
from collections import defaultdict
from enum import Enum
from typing import Dict
@ -12,6 +16,8 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.model.moe import MoE
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
@ -25,8 +31,6 @@ from internlm.utils.storage_manager import (
logger = get_logger(__file__)
quit_signal_handler = None
class CheckpointType(Enum):
NORMAL_CHECKPOINT = 1
@ -69,6 +73,8 @@ def save_model_checkpoint(folder, model):
"""
states = model.state_dict()
# get non-moe parameters
states = get_non_moe_state_dict(states)
topo = get_model_topology(model)
if folder is not None:
@ -92,6 +98,9 @@ def save_model_checkpoint(folder, model):
topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo)
# move the judgement logic into save_moe_checkpoint(.)
try_save_moe_checkpoint(folder, model)
torch.distributed.barrier()
@ -128,6 +137,18 @@ def load_model_checkpoint(folder, model):
fp = os.path.join(folder, should_load_name)
states = llm_load(fp, map_location=get_current_device())
"""
# The gate parameters would need to be converted to float32 (to fit the deepspeed-style mechanism),
# which may cause round-off in gate.weight. The conversion is also done during forward, so we can
# leave this commented out; the gate parameters then stay float16 before forward.
for key in list(states.keys()):
if 'moe_layer.gate.wg.weight' in key:
states[key] = states[key].float()
print("load: ", states[key].float(),flush=True)
"""
try_load_moe_checkpoint(folder, model, states)
missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
if len(missing_k) != 0:
logger.warning(f"Warning: missing keys {missing_k}")
@ -139,6 +160,58 @@ def load_model_checkpoint(folder, model):
torch.cuda.empty_cache()
def try_save_moe_checkpoint(folder, model):
# Using layer_#_expert_# to save the model's expert state_dict, a hack.
moe_layer_id = 0
for n_module, module in model.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
num_local_experts = module.num_local_experts
expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
# get all moe parameters
moe_state_dict = {}
for n, p in module.state_dict().items():
if "expert" in n and "moe_layer.gate.wg.weight" not in n:
moe_state_dict[n_module + "." + n] = p
moe_str_prefix = ".moe_layer.experts.experts."
# Reorder the moe name rank, so that each checkpoint only has one expert
experts_state_dict = defaultdict(dict)
for key in list(moe_state_dict.keys()):
m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)
local_expert_id = None
if not m:
logger.warning(f"No expert found in key {key}.")
else:
local_expert_id = m.group(1)
global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")
# truncating extra tensor (shared) storage
truncated = moe_state_dict.pop(key).clone().detach()
experts_state_dict[str(global_expert_id)][expert_key] = truncated
# now save the moe parameters, one checkpoint file per global expert
for global_expert_id, expert_state_dict in experts_state_dict.items():
# save the moe parameters
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
fp = os.path.join(folder, fn)
llm_save(fp, saved_obj=expert_state_dict)
moe_layer_id += 1
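# Worked example of the expert-id remapping above (hypothetical sizes and parameter name): with
# num_local_experts=2 and expert-parallel rank expp_rank=3, local expert 1 becomes global expert
# 3 * 2 + 1 = 7, a key ending in "moe_layer.experts.experts.1.w1.weight" is rewritten to end in
# "moe_layer.experts.experts.7.w1.weight", and that shard of the first MoE layer is written to
# "model_moe_layer0_expert7.pt".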
def get_non_moe_state_dict(full_state_dict):
"""
Get the state dict of the non-moe layers
"""
for key in list(full_state_dict.keys()):
if "expert" in key and "moe_layer.gate.wg.weight" not in key:
full_state_dict.pop(key)
return full_state_dict
def save_optimizer_checkpoint(optim, state_path):
"""Store the state of the optimizer to the local file system or remote OSS.
@ -167,42 +240,25 @@ def save_optimizer_checkpoint(optim, state_path):
llm_save(os.path.join(state_path, fp), states)
def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
"""
Save checkpoint to the given folder path.
"""
start = time.time()
torch.distributed.barrier()
folder = os.path.join(folder, str(train_state.step_count))
logger.info(
f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..."
)
timer("save-model").start()
save_model_checkpoint(folder=folder, model=model)
timer("save-model").stop()
timer("save-optimizer").start()
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()
if gpc.is_rank_for_log():
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
sampler_state = train_state.batch_sampler.state_dict()
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
if model_config is not None:
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
torch.distributed.barrier()
if gpc.is_rank_for_log():
timer.log(["save-model", "save-optimizer"], logger=logger)
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
def try_load_moe_checkpoint(folder, model, state_dict):
moe_layer_id = 0
for _, module in model.named_modules():
if isinstance(module, MoE): # and deepspeed.comm.get_rank() == 0:
num_local_experts = module.num_local_experts
expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
# loop all local_experts
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
fp = os.path.join(folder, fn)
expert_state_dict = llm_load(fp, map_location=get_current_device())
# Updating global -> local expert ids
moe_str_prefix = ".moe_layer.experts.experts."
for key in list(expert_state_dict.keys()):
local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
moe_layer_id += 1
def load_optimizer_checkpoint(folder, optim):
@ -304,19 +360,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
logger.info(f"reload load_scheduler:{lr_scheduler}")
class CheckpointSaveManager:
class CheckpointManager:
"""StorageManagerContext"""
def __init__(
self,
ckpt_config,
model,
optimizer,
lr_scheduler,
model_config,
) -> None:
def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
"""
CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous
CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
upload mode, you must call wait_async_upload_finish at the end of the program to wait
for the asynchronous ckpt upload to complete.
@ -332,26 +381,95 @@ class CheckpointSaveManager:
self.save_ckpt_folder = ckpt_config.save_ckpt_folder
self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
self.stop_file_path = ckpt_config.stop_file_path
self.load_model_only_folder = ckpt_config.load_model_only_folder
self.feishu_address = feishu_address
self.storage_manager = get_storage_manager()
self.snapshot_counter = 0
self.load_optimizer = gpc.config.ckpt.load_optimizer
self.model = model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.model_config = model_config
if self.stop_file_path and gpc.get_global_rank() == 0:
dir_path = os.path.dirname(self.stop_file_path)
if dir_path != "" and not os.path.exists(dir_path):
os.makedirs(dir_path)
with open(self.stop_file_path, "w", encoding="utf-8") as f:
f.write("0")
if ckpt_config.load_given_ckpt is False:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
latest_ckpt_path = self.query_lastest_ckpt()
if latest_ckpt_path:
self.load_ckpt_folder = latest_ckpt_path
else:
# At this time, we have to load model init weights and train from step 0.
self.load_ckpt_folder = self.load_model_only_folder
else:
self.load_ckpt_folder = ckpt_config.load_ckpt_folder
if gpc.is_rank_for_log():
logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
if self.stop_file_path is None:
logger.warning("no set stop_file_path, quit_signal_handler is disable")
def quit_signal_handler(self, train_state) -> bool:
"""
Exit-signal detection function: if an exit step is written to the stop file ('QUIT_FILE_PATH'),
all ranks will save a ckpt and, depending on the sign, exit.
A negative integer step means save ckpt only.
A positive integer step means save ckpt and quit.
Args:
train_state (TrainState): current training state.
Returns:
Tuple[bool, bool, CheckpointType]: whether to quit, whether to save a ckpt, and the ckpt type.
"""
now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
if self.stop_file_path is None:
return now_break, now_save_ckpt, save_type
with open(self.stop_file_path, "a+", encoding="utf-8") as f:
fcntl.flock(f, fcntl.LOCK_EX)
f.seek(0)
msg = f.read()
fcntl.flock(f, fcntl.LOCK_UN)
action_step = int(msg)
if action_step < 0 and abs(action_step) == train_state.step_count:
now_save_ckpt = True
if action_step > 0 and action_step == train_state.step_count:
now_break, now_save_ckpt = True, True
if action_step != 0 and gpc.is_rank_for_log():
msg = "Stop" if action_step > 0 else "Save"
action_step = abs(action_step)
if train_state.step_count <= action_step:
if self.feishu_address:
send_alert_message(
address=self.feishu_address,
message=f"training will {msg} at step_count {action_step}!\
now step_count is {train_state.step_count}",
)
return now_break, now_save_ckpt, save_type
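# A minimal sketch of how an operator could drive the stop-file protocol above (assumption:
# stop_file_path is the same file configured via ckpt_config.stop_file_path; file locking is
# omitted for brevity):
def _example_request_save_or_stop(stop_file_path, step, also_quit=False):
    # a positive step asks every rank to save a ckpt and quit at that step,
    # a negative step asks for a save only
    with open(stop_file_path, "w", encoding="utf-8") as f:
        f.write(str(step if also_quit else -step))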
def try_save_checkpoint(self, train_state):
if not self.enable_save_ckpt:
return
return False
save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0:
save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
if save_ckpts is False:
if quit_signal_handler is not None:
save_ckpts, save_type = quit_signal_handler(train_state)
save_ckpts = singal_save_ckpts
save_type = singal_save_type
if save_ckpts:
# Wait for the previous round of asynchronous upload storage to complete.
@ -361,9 +479,9 @@ class CheckpointSaveManager:
self.snapshot_counter = (self.snapshot_counter + 1) % 2
save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
else:
save_ckpt_folder = self.save_ckpt_folder
save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count))
save_checkpoint(
self.save_checkpoint(
folder=save_ckpt_folder,
model=self.model,
optimizer=self.optimizer,
@ -372,7 +490,221 @@ class CheckpointSaveManager:
model_config=self.model_config,
)
return now_break
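# Worked example for the save-type decision above (hypothetical frequencies): with
# checkpoint_every=100 and oss_snapshot_freq=40, step 40 triggers a SNAPSHOT_CHECKPOINT, step
# 100 a NORMAL_CHECKPOINT, and step 200 matches both tests but the later normal-ckpt check
# wins; snapshot saves alternate between ".../snapshot/0" and ".../snapshot/1" via
# snapshot_counter.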
def wait_async_upload_finish(self):
"""wait for all checkpoint uploads to be completed"""
self.storage_manager.wait()
torch.distributed.barrier()
def query_latest_snapshot_step_boto3(self):
"""query_latest_snapshot_step_boto3
Returns:
Tuple(str, int): path of the latest ckpt and its step; (None, None) is returned if nothing is found.
"""
ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
if len(ckpt_list) == 0:
return None, None
max_normal_step = 0
ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
ckpt_list.sort(reverse=True)
for ckpt in ckpt_list:
fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
for fn in fns_list:
if fn.endswith(".step"):
max_normal_step = ckpt
break
if max_normal_step != 0:
break
max_normal_step = ckpt_list[0]
load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))
snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
max_step_0, max_step_1 = 0, 0
for ckpt in ckpt_list_1:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
for ckpt in ckpt_list_2:
ckpt = ckpt.strip("/")
if ckpt.endswith(".step"):
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
snap_step = max(max_step_0, max_step_1)
load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
load_step = max(snap_step, max_normal_step)
return load_path, load_step
def query_latest_snapshot_step_local(self):
max_step, max_step_path = 0, None
for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
for fn in files:
fn = fn.strip("/")
if fn.endswith(".step"):
# We assume that both normal ckpt and snapshot ckpt will store the '.step' file
# as an integrity flag.
step = int(fn.rsplit(".", maxsplit=1)[0])
if max_step < step:
max_step = step
max_step_path = root
return max_step_path, max_step
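# Illustrative layout that the local query above would walk (hypothetical paths):
#   <save_ckpt_folder>/50/50.step          -> completed normal ckpt at step 50
#   <save_ckpt_folder>/snapshot/0/75.step  -> completed snapshot at step 75
# The newest directory containing a "<step>.step" integrity flag wins, so this returns
# ("<save_ckpt_folder>/snapshot/0", 75).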
def query_lastest_ckpt(self):
latest_checkpoint = None
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder:
if self.save_ckpt_folder.startswith("boto3"):
latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
elif self.save_ckpt_folder.startswith("local"):
latest_checkpoint, step = self.query_latest_snapshot_step_local()
else:
latest_checkpoint, step = None, 0
if latest_checkpoint is not None:
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
send_alert_message(
address=self.feishu_address,
message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
)
else:
if gpc.is_rank_for_log():
send_alert_message(
address=self.feishu_address,
message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
)
return latest_checkpoint
def try_load_model(self, current_time=""):
model_load_path = None
if self.load_ckpt_folder and self.load_model_only_folder:
raise ValueError(
"Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
if you only need to load model weights (for example starting an SFT task for the first time), \
set load_model_only_folder path, if you need to resume training from ckpt, \
set load_ckpt_folder or use default value \
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
)
if self.load_ckpt_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_ckpt_folder
elif self.load_model_only_folder:
if gpc.is_rank_for_log():
logger.info(
f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = self.load_model_only_folder
else:
if gpc.is_rank_for_log():
logger.info(
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
# Loading model weights must be done before zero is initialized.
if model_load_path is not None:
load_model_checkpoint(folder=model_load_path, model=self.model)
def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
"""Attempt to restore the training state of the last ckpt.
Args:
lr_scheduler (_LRScheduler): lr_scheduler object.
optimizer (Optimizer): optimizer object.
lr (float): learning rate.
train_state (dict): training states.
train_dl (DataLoader): training dataloader object.
"""
if self.load_ckpt_folder is not None:
# load optimizer states.
if self.load_optimizer:
load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
# load lr scheduler states.
load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
# load training states.
load_context(self.load_ckpt_folder, train_dl, train_state)
# load dataloader sampler states.
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
if hasattr(train_state, "data_state_dict"):
train_dl.dataset.load_state_dict(
llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
"""
Save checkpoint to the given folder path.
"""
start = time.time()
self.set_save_folder(folder, train_state.step_count)
torch.cuda.synchronize()
torch.distributed.barrier()
if gpc.is_rank_for_log():
logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
timer("save-model").start()
save_model_checkpoint(folder=folder, model=model)
timer("save-model").stop()
timer("save-optimizer").start()
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()
if (
hasattr(train_state, "data_state_dict")
and gpc.get_local_rank(ParallelMode.TENSOR) == 0
and gpc.get_local_rank(ParallelMode.PIPELINE) == 0
):
llm_save(
os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"),
saved_obj=train_state.data_state_dict,
)
if gpc.is_rank_for_log():
scheduler_states = scheduler.state_dict()
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
if hasattr(train_state, "batch_sampler") and not isinstance(
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
):
sampler_state = train_state.batch_sampler.state_dict()
llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())
if model_config is not None:
llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)
torch.distributed.barrier()
if gpc.is_rank_for_log():
timer.log(["save-model", "save-optimizer"], logger=logger)
logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
if self.storage_manager.async_mode is False:
llm_save(
os.path.join(folder, f"{train_state.step_count}.step"),
saved_obj=dict({"step": train_state.step_count}),
)
def set_save_folder(self, folder, step):
self.storage_manager.latest_save_folder = folder
self.storage_manager.latest_save_step = step

@ -1,15 +1,13 @@
import os
import time
from collections import OrderedDict
from functools import partial
from functools import partial, reduce
from typing import Any, Dict, List, Tuple
import pyecharts
import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.solver.pipeline_utils import partition_uniform
from internlm.core.naive_amp import NaiveAMPModel
mb = 1024 * 1024
@ -107,6 +105,8 @@ class SimpleMemState:
"""
Update the total memory usage of the model and sub-models.
"""
self._total_mem = self._layer_mem
for stat in self.sub_model_stats.values():
# Update sub-model status first.
stat.update_total_memory()
@ -169,6 +169,39 @@ class SimpleMemState:
return {"name": self.layer_name, "children": children}
class ActivationMemState:
"""
Activation Memory State
"""
def __init__(self, num_chunks: int) -> None:
self._num_chunks = num_chunks
self.inited: List[bool] = [False for _ in range(num_chunks)]
self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]
@property
def total_mem(self) -> int:
return sum(state.total_mem for state in self.states)
def dump(self, prefix: str = "") -> str:
return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states])
def to_json(self, base: int = 1024 * 1024) -> List:
return [state.to_json(base) for state in self.states]
def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1
if num_chunks > 1:
model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model])
else:
model = model.model if isinstance(model, NaiveAMPModel) else model
return model, num_chunks
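# Worked example for the unwrapping helper above (hypothetical setup): an interleaved-pipeline
# model arrives as torch.nn.ModuleList([NaiveAMPModel(chunk0), NaiveAMPModel(chunk1)]), so
# num_chunks == 2 and the profiler tracks the two bare chunks separately; a single
# NaiveAMPModel(model) gives num_chunks == 1 and only the wrapped module is unwrapped.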
class SimpleMemoryProfiler:
"""
A memory profiler for a llm model.
@ -177,7 +210,7 @@ class SimpleMemoryProfiler:
model (torch.nn.Module): The model to profile.
optimizer (torch.optim.Optimizer): The optimizer used for training the model.
log_file (str): The file to write the memory state information to.
activation_config (List[str], optional): The list of activation layers to track. Defaults to None.
total_steps: number of steps to trace.
"""
def __init__(
@ -186,9 +219,8 @@ class SimpleMemoryProfiler:
optimizer: torch.optim.Optimizer,
log_folder: str,
total_steps: int = 5,
activation_config: List[str] = None,
):
self._model = model
self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
self._optimizer = optimizer
self._log_folder = log_folder
self._remaining_steps = total_steps
@ -197,17 +229,20 @@ class SimpleMemoryProfiler:
self._record_start_time = time.time()
# For activation memory state.
self._activation_config = activation_config
self._activation_mem_inited: bool = False
self._activation_mem: int = 0
self._activation_max_count = 0
self._activation_base_mem: SimpleMemState = SimpleMemState("activations")
self._activation_mem_max: int = 0
self._activation_base_mems = ActivationMemState(self._num_model_chunks)
# Check or create log folder
os.makedirs(self._log_folder, exist_ok=True)
# Register activation memory tracking hooks
self._register_activation_trace_hooks()
if self._num_model_chunks > 1:
for chunk_id in range(self._num_model_chunks):
self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
else:
self._register_activation_trace_hooks(0, self._model)
# Calculate static parameter cuda memory
self._param_mem_state = SimpleMemState("param_mem")
@ -221,7 +256,7 @@ class SimpleMemoryProfiler:
self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))
# Generate the first memory record
self.point(create=True)
self.point(with_options="params,grads,os_params", create=True)
def point(self, with_options: str = "", create: bool = False) -> None:
"""
@ -272,7 +307,7 @@ class SimpleMemoryProfiler:
if "os_state" in options:
layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
if "activation_base" in options:
layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump()
layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()
# Write memory state information to log file
file_mode = "w" if create else "a"
@ -315,14 +350,14 @@ class SimpleMemoryProfiler:
[self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
"os_memory_sunburst",
)
self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst")
self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")
# Generate summary sunburst chart
summary_sunburst_data = [
{"name": "params", "value": self._param_mem_state.total_mem // mb},
{"name": "grads", "value": self._grad_mem_state.total_mem // mb},
{"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
{"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
{"name": "activation", "value": self._activation_base_mem.total_mem // mb},
{"name": "activation", "value": self._activation_mem_max // mb},
]
self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
@ -337,12 +372,13 @@ class SimpleMemoryProfiler:
{},
{
"r0": "10%",
"r": "40%",
"r": "35%",
"itemStyle": {"borderWidth": 3},
"label": {"align": "left"},
},
{"r0": "40%", "r": "65%", "label": {"align": "left"}},
{"r0": "65%", "r": "80%", "label": {"align": "left"}},
{"r0": "35%", "r": "55%", "label": {"align": "left"}},
{"r0": "55%", "r": "70%", "label": {"align": "left"}},
{"r0": "70%", "r": "80%", "label": {"align": "left"}},
{"r0": "80%", "r": "90%", "label": {"align": "left"}},
{
"r0": "90%",
@ -357,7 +393,14 @@ class SimpleMemoryProfiler:
f"{self._log_folder}/{name}.html"
)
def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None:
def _inner_activation_trace_hook(
self,
chunk_id: int,
layer_name: str,
model: Any,
inputs: Any,
output: torch.Tensor,
) -> None:
"""
Hook function to trace the activation memory usage for an inner layer.
@ -373,13 +416,15 @@ class SimpleMemoryProfiler:
del model, inputs
assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"
if self._stoped or self._activation_mem_inited:
if self._stoped or self._activation_base_mems.inited[chunk_id]:
return
# Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False)
self._activation_base_mems.states[chunk_id].add(
layer_name, output.element_size() * output.nelement(), flush=False
)
def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None:
def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
"""
Hook function to trace the activation memory usage for a forward pass.
@ -398,23 +443,24 @@ class SimpleMemoryProfiler:
return
# Check if the activation memory has been initialized
if self._activation_mem_inited is False:
if self._activation_base_mems.inited[chunk_id] is False:
self._activation_base_mems.inited[chunk_id] = True
# Update the total memory of the activation base memory state
self._activation_base_mem.update_total_memory()
self._activation_base_mems.states[chunk_id].update_total_memory()
# Set with_options to "activation_base" to include activation_base_layout in the memory dump
self._activation_mem_inited = True
with_options = "activation_base"
else:
with_options = ""
# Accumulate activation memory usage for each forward pass
self._activation_mem += self._activation_base_mem.total_mem
# Update activation max count
if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count:
self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem
self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
if self._activation_mem > self._activation_mem_max:
self._activation_mem_max = self._activation_mem
# Trigger a memory record
self.point()
self.point(with_options)
def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None:
def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
"""
Hook function to trace the activation memory usage for a backward pass.
@ -432,37 +478,28 @@ class SimpleMemoryProfiler:
return
# Release activation memory usage for each backward pass
self._activation_mem -= self._activation_base_mem.total_mem
self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem
# Trigger a memory record
self.point()
def _register_activation_trace_hooks(self) -> None:
def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
"""
Register activation trace hooks for the model and each submodule in the model.
"""
# Register inner activation trace hooks for each submodule in the model
for layer_name in self._activation_config:
# Register a hook for every activation
model = self._model
sub_models = layer_name.split(".")
# Get the target sub-model
for sub_model_name in sub_models:
try:
model = model.get_submodule(sub_model_name)
except AttributeError:
model = None
break
for layer_name, sub_model in model_chunk.named_modules():
# Register the hook
if model is not None:
model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name))
if len(sub_model._modules) != 0:
continue # TODO: in some special cases, we may need some additional configuration to correct
sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))
# Register a forward hook for the main model to track activation memory usage
self._model.register_forward_hook(self._activation_trace_hook_forward)
model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
# Register a backward hook for the main model to release activation memory usage
self._model.register_full_backward_hook(self._activation_tarce_hook_backward)
model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))
def _calc_tensor_memory(
self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False
@ -554,48 +591,6 @@ class SimpleMemoryProfiler:
self._calc_tensor_memory(root_stat, named_tensors)
def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]:
# TODO: support interleaved pipeline scheduling.
assert num_chunks == 1, "Only support num_chunks == 1"
if gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
else:
pipeline_size = 1
pipeline_rank = 0
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
parts = all_parts[pipeline_rank]
start, end = parts[0]
num_blocks = end - start
block_conf_tmpl = [
"mixer.rotary_emb",
"mixer.Wqkv",
"mixer.inner_attn",
"mixer.inner_cross_attn",
"mixer.out_proj",
# "dropout1", # skip when dropout_selective_checkpoint is True
# "dropout2", # skip when dropout_selective_checkpoint is True
"norm1",
"norm2",
"mlp.w1",
"mlp.w2",
"mlp.w3",
]
block_conf = []
for block_id in range(num_blocks):
block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl]
# We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning.
# If they don't exist, they will be automatically ignored when registering activation trace hooks.
activation_conf = ["embedding", "norm", "head"] + block_conf
return activation_conf
if __name__ == "__main__":
class SimpleModel(torch.nn.Module):
@ -635,32 +630,39 @@ if __name__ == "__main__":
return output
def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
if _num_chunks > 1:
_output = _input
for _model_chunk in _model_chunks:
_output = _model_chunk(_output)
else:
_output = _model_chunks(_input)
return _output
# num_chunks config
_num_chunks = 1
# init model and optimizer
_model: torch.nn.Module = SimpleModel()
if _num_chunks > 1:
_chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
_model = torch.nn.ModuleList(_chunks).cuda()
else:
_model: torch.nn.Module = SimpleModel().cuda()
_optimizer = torch.optim.Adam(_model.parameters())
# create activation config for simple model layer by layer.
activation_configs = [
# model level 0
"layer1",
"layer2",
"layer3",
# model level 1
"layer2.layer1",
"layer2.layer3",
]
_model.modules()
# init profiler
profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs)
profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)
_optimizer.zero_grad()
x1 = torch.randn((128, 5120))
x2 = torch.randn((128, 5120))
out1 = _model(x1)
out2 = _model(x2)
# inputs
x1 = torch.randn((128, 5120)).cuda()
x2 = torch.randn((128, 5120)).cuda()
# forward
out1 = _simple_schedule(_num_chunks, _model, x1)
out2 = _simple_schedule(_num_chunks, _model, x2)
# backward
out1.mean().backward()
out2.mean().backward()

@ -15,8 +15,6 @@ from asyncio.tasks import ALL_COMPLETED
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Union
import boto3
import botocore
import torch
import torch.distributed as dist
@ -24,6 +22,13 @@ from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
try:
import boto3
import botocore
except ImportError:
pass
logger = get_logger(__file__)
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
@ -234,13 +239,13 @@ class Boto3Client(StorageClient):
"""
paginator = handler.client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
folder_name_list = []
for page in pages:
for obj in page["Contents"]:
fp: str = obj["Key"]
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
return folder_name_list
if "Contents" in page:
for obj in page["Contents"]:
pth: str = obj["Key"]
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
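# Worked example for the prefix handling above (hypothetical keys): with fp="ckpts/7b" and a
# page listing "ckpts/7b/50/model_tp0_pp0.pt" and "ckpts/7b/50/50.step", splitting each key on
# the prefix and keeping the first path component yields {"50"}, so get_fns returns top-level
# entry names under the prefix rather than bare file basenames.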
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
@ -391,6 +396,11 @@ class StorageManager(metaclass=SingletonMeta):
self.tmp_local_folder = tmp_local_folder
self.async_mode = async_mode
self.has_warning = False
self._async_loop = None
self._thread_pool = None
self.latest_save_folder = None
self.latest_save_step = 0
self.async_task_peeding = False
if enable_save and self.async_mode:
self._async_loop = asyncio.new_event_loop()
@ -485,6 +495,7 @@ class StorageManager(metaclass=SingletonMeta):
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
self.async_task_peeding = True
else:
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
self.upload_count += 1
@ -523,23 +534,22 @@ class StorageManager(metaclass=SingletonMeta):
pass
async def _sync_tasks(self) -> Awaitable[None]:
if not self._async_stack:
return
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
for task in self._async_stack:
try:
task.exception()
except InvalidStateError:
continue
except Exception as e:
file_id = len(self._exception_list)
self._exception_list.append((e, file_id))
logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}")
self._async_stack.clear()
if self._async_stack:
await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
count = 0
while self._async_stack:
t = self._async_stack[0]
try:
e = t.exception()
if e:
self._exception_list.append((e, count))
logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}")
# raise e
count += 1
self._async_stack.pop(0)
except InvalidStateError:
# Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
pass
def async_executor(self, fn: Callable, *args, **kwargs) -> None:
"""
@ -559,11 +569,14 @@ class StorageManager(metaclass=SingletonMeta):
if not self.async_mode:
return
if not self.async_task_peeding:
return
if self._async_loop:
self._async_loop.run_until_complete(self._sync_tasks())
if self._exception_list:
for file_id, error_msg in self._exception_list:
for error_msg, file_id in self._exception_list:
logger.error(
f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
f"failed on step {self.upload_count}: {error_msg}"
@ -577,10 +590,16 @@ class StorageManager(metaclass=SingletonMeta):
self._del_tmp_folder()
self._exception_list.clear()
self._to_be_del_files.clear()
self.async_task_peeding = False
if gpc.is_rank_for_log():
logger.info("all async uploads succeeded!")
self.upload_count += 1
if self.async_mode:
self.save(
os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
saved_obj=dict({"step": self.latest_save_step}),
async_upload=False,
)
storage_manager: StorageManager = None

@ -11,10 +11,6 @@ from torch.utils.tensorboard import SummaryWriter
from internlm.core.context import global_context as gpc
def copy_ignore_folder(source_path, target_path):
os.system(f"cp -r {source_path}/* {target_path}/")
def tb_save_run_info(writer, config_lines, global_step=0):
writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step)
lines = []
@ -44,7 +40,8 @@ def init_tb_writer(
if gpc.get_global_rank() == 0:
if resume_tb_folder is not None:
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
copy_ignore_folder(resume_tb_folder, tb_folder)
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
os.system(f"chmod -R +w {tb_folder}/")
else:
logger.info(f"Login tensorboard logs to: {tb_folder}")

623
train.py
@ -5,99 +5,48 @@ import socket
import time
import traceback
from functools import partial
from typing import Iterable
import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader
import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.scheduler import SchedulerMetricHook
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
from internlm.data.dataset import get_dataset_dict
from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import (
PackedDataset,
PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.model.moe import create_moe_param_groups, has_moe_layers
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
from internlm.monitor import initialize_monitor_manager, send_alert_message
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
initialize_distributed_env,
initialize_llm_profile,
initialize_model,
initialize_optimizer,
load_new_batch,
record_current_batch_training_metrics,
)
from internlm.utils.common import (
BatchSkipper,
get_master_node,
get_megatron_flops,
launch_time,
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import (
CheckpointSaveManager,
load_context,
load_model_checkpoint,
load_optimizer_checkpoint,
load_sampler,
load_scheduler,
)
from internlm.utils.parallel import (
get_parallel_log_file_name,
is_no_pp_or_last_stage,
sync_model_param_with_ep,
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.simple_memory_profiler import (
SimpleMemoryProfiler,
build_activation_config,
)
from internlm.utils.model_checkpoint import CheckpointManager
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
from internlm.utils.writer import Writer
# global llm logger
logger = get_logger(__file__)
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
"""
Initialize distributed environment for distributed training.
Args:
config (str): Config file path.
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
master_port (str): The master port for distributed training. 8888 by default.
seed (int, optional): Specified random seed for every process. 1024 by default.
"""
torch.cuda.empty_cache()
if launcher == "torch":
internlm.launch_from_torch(config=config, seed=seed)
elif launcher == "slurm":
internlm.launch_from_slurm(
config=config,
host=get_master_node(),
port=master_port,
seed=seed,
)
else:
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
def initialize_llm_logger(start_time: str):
"""
Initialize customized uniscale logger.
@ -118,338 +67,14 @@ def initialize_llm_logger(start_time: str):
return uniscale_logger
def initialize_model():
"""
Initialize model.
Returns: The neural network model to be trained or evaluated.
"""
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
if isinstance(model, nn.ModuleList):
model = nn.ModuleList(
[
NaiveAMPModel(
model=_m,
output_to_fp32=False, # manually controlled by interleaved pipleline scheduler
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
for _m in model
]
)
else:
model = NaiveAMPModel(
model=model,
output_to_fp32=is_no_pp_or_last_stage(),
dtype=gpc.config.model.get("dtype", torch.half),
sync_buffer=False,
)
# This sync is very important, because the model weights kept in the optimizer are copied
# from the original parameters in memory, so we should make sure the dp sync
# does not cause the model weights in the optimizer to differ from the original parameters.
sync_model_param_with_ep(model)
# This function is needed to make sure parameters that are not split by tensor parallelism are
# the same across tensor-parallel ranks.
sync_model_param_within_tp(model)
return model
def get_train_data_loader(num_worker: int = 0):
"""
Generate and return the training data loader.
Returns: A tuple of (train_dl, dataset_types).
"""
# Get the dataset types
dataset_types = None
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
data_cfg = gpc.config.data
# Get the sample weight dictionary
train_folder = data_cfg.train_folder
if not train_folder:
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
if data_cfg.pack_sample_into_one:
train_ds = PackedDatasetWithoutCuSeqlen(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = PackedDataset(
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
)
else:
train_ds = get_packed_dataset_without_short_length(
folder=data_cfg.train_folder,
packed_length=data_cfg.packed_length,
max_length_per_sample=data_cfg.seq_len,
show_progress=dist.get_rank() == 0,
min_length=data_cfg.min_length,
min_length_dict=data_cfg.get("min_length_dict", {}),
pack_into_one_sample=data_cfg.pack_sample_into_one,
)
# partition already completed
# assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen))
if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)):
datasets = [train_ds]
else:
datasets = train_ds.datasets
# Create the training dataset sampler
train_sampler = StaticBatchSampler(
datasets,
batch_size=data_cfg.micro_num,
rampup_batch_size=data_cfg.rampup_batch_size,
micro_bsz=data_cfg.micro_bsz,
seed=1024,
drop_last=True,
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
# Create the training data loader
train_dl = DataLoader(
dataset=train_ds,
batch_sampler=train_sampler,
num_workers=num_worker,
pin_memory=True,
collate_fn=train_collate_fn,
persistent_workers=True,
)
return train_dl, dataset_types
def get_validation_data_loader(num_worker: int = 0):
"""Generate and return the validation data loader."""
data_cfg = gpc.config.data
if not data_cfg.valid_folder:
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
else:
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
if not isinstance(val_ds, dict):
val_ds = {"val": val_ds}
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
val_dls = {}
for val_name, ds in val_ds.items():
# making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
# otherwise too much data may be dropped
batch_size = min(
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
)
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
if batch_size == 0 and gpc.is_rank_for_log():
logger.info(f"skip validate {val_name}.") # pylint: disable=W1203
continue
val_dls[val_name] = get_dpsampler_dataloader(
ds, shuffle=False, num_workers=num_worker, batch_size=batch_size, collate_fn=val_collate_fn, drop_last=True
) # drop_last=True, otherwise it may cause problems in the last batch
if gpc.is_rank_for_log():
logger.info( # pylint: disable=W1203
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
f"samples {str(len(val_dls[val_name]))}."
)
return val_dls
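# val_dls maps each validation set name to its DataLoader; main() later passes this dict to
# evaluate_on_val_dls(trainer=trainer, val_dls=val_dls, ...) every `valid_every` steps.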
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return a new batch of data from the training data loader.
Args:
train_dl (torch.utils.data.DataLoader): DataLoader for training.
train_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
train_state (TrainState): Current training state.
Returns: A batch of data and the updated train_iter.
"""
timer("batch-gen").start()
try:
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
next(train_state.batch_sampler_iter)
except StopIteration:
train_iter = iter(train_dl)
batch = next(train_iter)
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
next(train_state.batch_sampler_iter)
train_state.num_consumed_samples_in_epoch = 0
timer("batch-gen").stop()
return batch, train_iter
def initialize_optimizer(model: nn.Module):
"""
Initialize optimizer.
Args:
model (torch.nn.Module): Your model instance to be trained or evaluated.
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
adam_cfg = gpc.config.adam
if gpc.config.model.num_experts > 1:
params = create_moe_param_groups(model, adam_cfg.weight_decay)
else:
params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
naive_optimizer = torch.optim.AdamW(
params=params,
lr=adam_cfg.lr,
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
eps=adam_cfg.adam_eps,
)
has_moe = has_moe_layers(model)
optimizer = HybridZeroOptimizer(
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
has_moe=has_moe,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)
return optimizer, beta2_scheduler, lr_scheduler
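# Rough sketch of the MoE parameter grouping assumed above (simplified; not the actual
# create_moe_param_groups implementation, and the "expert" name filter / "moe" flag are
# illustrative assumptions):
#
#     dense = [p for n, p in model.named_parameters() if "expert" not in n]
#     expert = [p for n, p in model.named_parameters() if "expert" in n]
#     params = [
#         {"params": dense, "weight_decay": adam_cfg.weight_decay},
#         {"params": expert, "weight_decay": adam_cfg.weight_decay, "moe": True},
#     ]
#
# The idea is simply that expert parameters get their own param group so the ZeRO optimizer
# can treat them separately from dense parameters.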
def record_current_batch_training_metrics(
get_tflops_func,
logger,
writer,
success_update,
batch_count,
batch,
train_state,
optimizer,
beta2_scheduler,
trainer,
start_time,
loss,
grad_norm,
metric,
update_panel,
):
"""
Print the training metrics of the current batch.
"""
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
if success_update in (0, True):
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
if is_no_pp_or_last_stage():
acc_perplex = metric.get_metric()
if success_update and gpc.is_rank_for_log():
lr = optimizer.param_groups[0]["lr"]
if hasattr(trainer.engine.optimizer, "grad_scaler"):
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
num_tokens_in_batch = batch[1].nelement()
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
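# Hypothetical example: a micro batch with cu_seqlens = [0, 5, 12, 16] packs 3 samples of
# lengths 5, 7 and 4, so len(b) - 1 counts samples and (b[1:] - b[:-1]).max() is the longest sample.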
tk_per_gpu = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
/ (time.time() - start_time),
2,
)
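# Example with assumed numbers: 16384 tokens per DP rank, a DATA world size of 8, a GLOBAL
# world size of 32 and a 4-second step give round(16384 * 8 / 32 / 4, 2) = 1024.0 tokens/gpu/s.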
tflops = get_tflops_func((time.time() - start_time))
infos = {
"tflops": tflops,
"step": batch_count,
"loss": loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu,
"lr": lr,
"loss_scale": scaler,
"grad_norm": grad_norm,
}
infos["micro_num"] = len(batch[1])
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
infos["largest_length"] = max_length_in_batch # the longest input
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
infos["smallest_batch"] = min_samples_in_batch
infos["adam_beta2"] = beta2_scheduler.get_beta2()
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
infos["fwd_bwd_time"] = fwd_bwd_time
for key, value in acc_perplex.items():
infos[key] = value
line = ""
for key, value in infos.items():
line += f"{key}={value} "
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if update_panel:
logger.info(
line,
extra={
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"grad_norm": grad_norm,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
},
)
else:
logger.info(line)
# if loss spike occurs, send alert info to feishu
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
def main(args):
# init setting
skip_batches = gpc.config.data.skip_batches
total_steps = gpc.config.data.total_steps
valid_every = gpc.config.data.valid_every
load_optimizer = gpc.config.ckpt.load_optimizer
label_smoothing = gpc.config.loss.label_smoothing
lr = gpc.config.adam.lr
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
get_tflops_func = partial(
get_megatron_flops,
checkpoint=gpc.config.model.checkpoint,
@@ -485,32 +110,19 @@ def main(args):
enable_tb=gpc.config.enable_tb,
)
model_load_path = None
if load_resume_ckpt_folder is not None:
logger.info( # pylint: disable=W1203
f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = load_resume_ckpt_folder
elif load_model_only_folder is not None:
logger.info( # pylint: disable=W1203
f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
f"{socket.gethostname()}==========="
)
model_load_path = load_model_only_folder
else:
logger.info( # pylint: disable=W1203
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
)
# initialize and resume train state
train_state = TrainState(gpc.config)
# initialize model
model = initialize_model()
ckpt_manager = CheckpointManager(
ckpt_config=gpc.config.ckpt,
model=model,
model_config=gpc.config.model,
feishu_address=gpc.config.alert_address,
)
# initialize loss function
criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@@ -520,30 +132,12 @@ def main(args):
train_state.init_batch_sampler(train_dl)
# Loading model weights must be done before zero is initialized.
ckpt_manager.try_load_model(current_time)
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)
# Loading other persistent training states.
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
@@ -579,12 +173,11 @@ def main(args):
# initialize simple memory profiler
if args.profiling:
memory_profiler = SimpleMemoryProfiler(
model,
optimizer.optim,
log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
)
else:
memory_profiler = None
@@ -597,86 +190,85 @@ def main(args):
# transfer the train data loader into train data iterator
train_iter = iter(train_dl)
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
# start iterating the train data and begin training
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()
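# free cached CUDA memory every 50 steps; the interval is a heuristic (assumed: to limit
# allocator fragmentation without paying the empty_cache() cost on every step)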
start_time = time.time()
timer("one-batch").start()
# load batch data
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
# record the consumed samples in training
train_state.batch_count = batch_count
train_state.num_consumed_samples_in_epoch += len(batch[1])
if batch_skipper(batch_count): # skip this batch
if gpc.is_rank_for_log():
logger.info(f"Skip batch count:`{batch_count}`...") # pylint: disable=W1203
timer("one-batch").stop()
continue
# zero the grads of parameters
trainer.zero_grad()
type_ids = batch[0].pop("type_ids", None)
# process data
# if use_flash_attn is False, we need to unpack type_ids
if not gpc.config.model.use_flash_attn:
type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
if type_ids is not None:
metric.set_current_type_ids(type_ids=type_ids)
# do forward and backward
timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(
batch,
forward_only=False,
return_loss=True,
return_output_label=False,
moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
)
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)
trainer_result = trainer.step()
assert trainer_result is not None
success_update, grad_norm_groups = trainer_result
if success_update: # update parameters successfully
train_state.step_count += 1
else:
train_state.inf_nan_skip_batches += 1 # record the number of batches whose parameter update was skipped
if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
)
# calculate and record the training metrics, e.g. loss, accuracy and so on.
record_current_batch_training_metrics(
get_tflops_func=get_tflops_func,
logger=logger,
writer=writer,
success_update=success_update,
batch_count=batch_count,
batch=batch,
train_state=train_state,
optimizer=optimizer,
beta2_scheduler=beta2_scheduler,
trainer=trainer,
start_time=start_time,
loss=loss,
grad_norm=np.array(grad_norm_groups),
metric=metric,
update_panel=uniscale_logger is not None,
)
timer("one-batch").stop()
with switch_sequence_parallel_mode():
# evaluate on validation data loaders
if valid_every > 0 and train_state.step_count % valid_every == 0:
evaluate_on_val_dls(
trainer=trainer,
val_dls=val_dls,
@@ -686,14 +278,19 @@ def main(args):
update_panel=uniscale_logger is not None,
)
# checkpoint the training states at specific steps, determined by the "checkpoint_every" config;
# the batch sampler that tracks the true consumed samples is saved as part of the checkpoint
now_break = ckpt_manager.try_save_checkpoint(train_state)
if now_break:
break
if memory_profiler is not None:
memory_profiler.step()
if batch_count % 2 == 0:
prof.step()
ckpt_manager.wait_async_upload_finish()
if __name__ == "__main__":