mirror of https://github.com/InternLM/InternLM
Merge branch 'feature_add_moe' into feature_add_moe_data
commit 0e2eb90d22

@@ -0,0 +1,152 @@
JOB_NAME = "7b_train"

SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 16
VOCAB_SIZE = 103168

MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"]  # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
    enable_save_ckpt=False,  # enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # load_ckpt_folder=LOAD_CKPT_FOLDER,  # Ckpt path to resume training (load weights and scheduler/context states).
    # load_model_only_folder=MODEL_ONLY_FOLDER,  # Path to initialize with given model weights.
    load_optimizer=True,  # Whether to load optimizer states when continuing training.
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload. (only works for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]),  # directory for snapshot ckpt storage path.
    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
)

TRAIN_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki"
VALID_FOLDER = "/mnt/petrelfs/share_data/llm_data/0623_scratch_tokenized_filtered/train/en/enwiki"
data = dict(
    seq_len=SEQ_LEN,
    # micro_num means the number of micro_batch contained in one gradient update
    micro_num=4,
    packed_length = 2 * SEQ_LEN,
    micro_bsz=2,
    # defaults to the value of micro_num
    valid_micro_num=4,
    # defaults to 0, which means evaluation is disabled
    valid_every=50000,
    pack_sample_into_one=False,
    total_steps=50000,
    skip_batches="",
    rampup_batch_size="",
    # Datasets with fewer than 50 rows will be discarded
    min_length=50,
    train_folder=TRAIN_FOLDER,
    valid_folder=VALID_FOLDER,
)

grad_scaler = dict(
    fp16=dict(
        # the initial loss scale, defaults to 2**16
        initial_scale=2**16,
        # the minimum loss scale, defaults to None
        min_scale=1,
        # the number of steps to increase loss scale when no overflow occurs
        growth_interval=1000,
    ),
    # the multiplication factor for increasing loss scale, defaults to 2
    growth_factor=2,
    # the multiplication factor for decreasing loss scale, defaults to 0.5
    backoff_factor=0.5,
    # the maximum loss scale, defaults to None
    max_scale=2**24,
    # the number of overflows before decreasing loss scale, defaults to 2
    hysteresis=2,
)

hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    zero_overlap_communication=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,
)

loss = dict(
    label_smoothing=0,
    moe_loss_coeff=0.1,
)

adam = dict(
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_beta2_c=0,
    adam_eps=1e-8,
    weight_decay=0.01,
)

lr_scheduler = dict(
    total_steps=data["total_steps"],
    init_steps=0,  # optimizer_warmup_step
    warmup_ratio=0.01,
    eta_min=1e-5,
    last_epoch=-1,
)

beta2_scheduler = dict(
    init_beta2=adam["adam_beta2"],
    c=adam["adam_beta2_c"],
    cur_iter=-1,
)

model = dict(
    checkpoint=False,
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYER,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.bfloat16",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    sequence_parallel=False,
    num_experts=4,
    moe_use_residual=False,
)
"""
zero1 parallel:
    1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
       so parameters will be divided within the range of dp.
    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
    3. if zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of the dp world size.
    For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
    1. size: int, the size of pipeline parallel.
    2. interleaved_overlap: bool, enable/disable communication overlap when using the interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
    # zero1=4,
    pipeline=dict(size=4, interleaved_overlap=False),
    # tensor=dict(size=4),
)

cudnn_deterministic = False
cudnn_benchmark = False

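
# --- Illustration (not part of the commit): how loss.moe_loss_coeff is consumed. ---
# The scheduler changes below combine the language-model loss with the MoE gate losses roughly as
# `total = lm_loss + moe_loss_coeff * sum(per_layer_gate_losses)`. A minimal sketch, assuming
# `moe_losses` is the list of per-layer auxiliary losses returned by the model's forward pass:
import torch

def combine_losses(lm_loss: torch.Tensor, moe_losses: list, moe_loss_coeff: float = 0.1) -> torch.Tensor:
    # weight the auxiliary load-balancing term and fold it into the training objective
    return lm_loss + sum(moe_losses) * moe_loss_coeff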
@@ -115,8 +115,9 @@ class NonPipelineScheduler(BaseScheduler):
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
moe_loss = sum(moe_losses) * moe_loss_coeff
loss += moe_loss
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss

# backward
if not forward_only:

@@ -127,7 +128,7 @@ class NonPipelineScheduler(BaseScheduler):
if not return_loss:
loss = None

return output, loss
return output, loss, moe_loss

def forward_backward_step(
self,

@@ -166,6 +167,7 @@ class NonPipelineScheduler(BaseScheduler):
data, label = batch_data

loss = 0 if return_loss else None
moe_loss = 0 if return_loss else None
outputs = []
labels = []

@@ -180,12 +182,14 @@ class NonPipelineScheduler(BaseScheduler):

_data, _label = self._load_accum_batch(data, label)

_output, _loss = self._train_one_batch(
_output, _loss, _moe_loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
)

if return_loss:
loss += _loss
moe_loss += _moe_loss

if return_output_label:
outputs.append(_output)
labels.append(_label)

@@ -193,4 +197,4 @@ class NonPipelineScheduler(BaseScheduler):
if not return_output_label:
outputs, labels = None, None

return outputs, labels, loss
return outputs, labels, loss, moe_loss

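
# --- Illustration (not part of the commit): per-microbatch loss scaling in the non-pipeline path. ---
# With gradient accumulation over N micro-batches, both the LM loss and the MoE loss of each
# micro-batch are divided by N before being combined and backpropagated, so the accumulated
# gradient matches one large batch. A minimal sketch, assuming `scale_loss` is that N:
def scale_and_combine(loss, moe_loss, scale_loss: int):
    moe_loss = moe_loss / scale_loss
    loss = loss / scale_loss
    loss = loss + moe_loss  # backward() then runs on the combined, already-scaled loss
    return loss, moe_loss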
@@ -7,6 +7,7 @@ from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union

import torch.cuda
import torch.distributed as dist

import internlm.core.communication as comm
from internlm.core.context import ParallelMode

@@ -130,7 +131,7 @@ class PipelineScheduler(BaseScheduler):
self.dtype = dtype
self._hooks = scheduler_hooks

self.tensor_shape = (
self._tensor_shape = (
tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
)

@@ -146,6 +147,14 @@ class PipelineScheduler(BaseScheduler):
# cache for the batch data
self.batch_data = None

@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape

@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape

def pre_processing(self, engine):
types = set()

@@ -239,7 +248,16 @@ class PipelineScheduler(BaseScheduler):
"""
return step_id

def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(
self,
engine,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None,
accum_moe_loss=None,
moe_loss_coeff=1.0,
):
"""
Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.

@@ -251,6 +269,7 @@ class PipelineScheduler(BaseScheduler):
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether to return output labels.
accum_loss (optional): Where the accumulated loss is stored.
accum_moe_loss (optional): Where the accumulated moe loss is stored.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
pipeline stage.

@@ -259,7 +278,7 @@ class PipelineScheduler(BaseScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model, data)
output_obj, moe_losses = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)

if gpc.is_last_rank(ParallelMode.PIPELINE):

@@ -275,9 +294,13 @@ class PipelineScheduler(BaseScheduler):
accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced

return output_obj
moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())

def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
return output_obj, moe_loss

def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss=None):
"""
Backward step through the passed-in output tensor. If it is the last stage, the
output_obj_grad is None, otherwise it is the gradients with respect to the stage's output tensor.

@@ -311,10 +334,18 @@ class PipelineScheduler(BaseScheduler):

self._call_hooks("before_backward", output_obj, output_obj_grad)
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
if output_obj_grad is None:
engine.backward(output_obj)
if moe_loss is None:
if output_obj_grad is None:
engine.backward(output_obj)
else:
engine.backward_by_grad(output_obj, output_obj_grad)
else:
engine.backward_by_grad(output_obj, output_obj_grad)
if output_obj_grad is None:
engine.backward(output_obj + moe_loss)
else:
# scale the latent loss
moe_loss = moe_loss * engine.optimizer.loss_scale
engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])

# Collect the grad of the input_obj.
input_obj_grad = None

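
# --- Illustration (not part of the commit): why a `None` gradient is passed for moe_loss. ---
# The MoE gate loss is produced locally on every pipeline stage, while output_obj_grad only
# carries the gradient arriving from the next stage. In torch.autograd.backward, a `None` entry
# in grad_tensors means "treat this output as a scalar loss with an implicit gradient of 1",
# which is what a stage-local scalar loss needs; the diff also multiplies moe_loss by the
# optimizer's loss scale so both terms are scaled consistently under fp16. A standalone sketch
# of the same call pattern:
import torch

x = torch.randn(4, requires_grad=True)
output_obj = x * 2                              # stands in for the stage's activation output
moe_loss = (x ** 2).mean()                      # stands in for the stage-local gate loss
output_obj_grad = torch.ones_like(output_obj)   # stands in for the gradient received from the next stage
torch.autograd.backward([output_obj, moe_loss], [output_obj_grad, None])
print(x.grad)                                   # both paths contribute to the input gradient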
@@ -329,7 +360,7 @@ class PipelineScheduler(BaseScheduler):

return input_obj_grad

def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
def _forward_only_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
"""
This function performs the forward-only computation process. The scheduling of microbatches is similar to the
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs

@@ -356,6 +387,7 @@ class PipelineScheduler(BaseScheduler):
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
accum_moe_loss = torch.zeros(1, device=get_current_device())

# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape

@@ -376,12 +408,14 @@ class PipelineScheduler(BaseScheduler):
input_obj = None

# Perform forward computation
output_obj = self._forward_step(
output_obj, _ = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)

if not gpc.is_last_rank(ParallelMode.PIPELINE):

@@ -392,10 +426,14 @@ class PipelineScheduler(BaseScheduler):
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

return output, label, accum_loss
if accum_loss is not None:
accum_loss += accum_moe_loss

def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
return output, label, accum_loss, accum_moe_loss

def _forward_backward_step(self, engine, return_loss=True, return_output_label=True, moe_loss_coeff=1.0):
"""
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
It consists of three stages: warmup, 1F1B, and cooldown.

@@ -441,12 +479,14 @@ class PipelineScheduler(BaseScheduler):
# Input, output tensors only need to be saved when doing backward passes
input_objs = []
output_objs = []
moe_losses = []
return_tensors = []
accum_loss = (
torch.zeros(1, device=get_current_device())
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
accum_moe_loss = torch.zeros(1, device=get_current_device())

# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape

@@ -468,12 +508,14 @@ class PipelineScheduler(BaseScheduler):
input_obj = None

# Perform forward computation
output_obj = self._forward_step(
output_obj, moe_loss = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)

if not gpc.is_last_rank(ParallelMode.PIPELINE):

@@ -493,6 +535,7 @@ class PipelineScheduler(BaseScheduler):

input_objs.append(input_obj)
output_objs.append(output_obj)
moe_losses.append(moe_loss)

# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to

@@ -512,12 +555,14 @@ class PipelineScheduler(BaseScheduler):
# Run 1F1B in steady state.
for i in range(num_1f1b_micropairs):
# Perform forward computation
output_obj = self._forward_step(
output_obj, moe_loss = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
accum_moe_loss=accum_moe_loss,
moe_loss_coeff=moe_loss_coeff,
)

if gpc.is_last_rank(ParallelMode.PIPELINE):

@@ -533,13 +578,15 @@ class PipelineScheduler(BaseScheduler):
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
moe_losses.append(moe_loss)

# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
moe_loss = moe_losses.pop(0)

input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)

if i == (num_1f1b_micropairs - 1):
input_obj = None

@@ -563,6 +610,7 @@ class PipelineScheduler(BaseScheduler):
for i in range(num_warmup_microsteps):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
moe_loss = moe_losses.pop(0)

if not gpc.is_last_rank(ParallelMode.PIPELINE):
output_obj_grad = comm.recv_backward(

@@ -574,17 +622,25 @@ class PipelineScheduler(BaseScheduler):
output_obj_grad = None

input_obj_grad = self._backward_step(
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad, moe_loss
)

if not gpc.is_first_rank(ParallelMode.PIPELINE):
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {accum_moe_loss.item()}")

output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

return output, label, accum_loss
if accum_loss is not None:
accum_loss += accum_moe_loss

def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
return output, label, accum_loss, accum_moe_loss

def forward_backward_step(
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.


@@ -596,7 +652,7 @@ class PipelineScheduler(BaseScheduler):
return_loss (bool, optional): Whether to return the loss value. Default is True.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss, moe_loss), loss and label could be None.
"""

assert (

@@ -607,9 +663,9 @@ class PipelineScheduler(BaseScheduler):
self.load_batch(engine, data_iter)

if forward_only:
return self._forward_only_step(engine, return_loss, return_output_label)
return self._forward_only_step(engine, return_loss, return_output_label, moe_loss_coeff)
else:
return self._forward_backward_step(engine, return_loss, return_output_label)
return self._forward_backward_step(engine, return_loss, return_output_label, moe_loss_coeff)


class InterleavedPipelineScheduler(PipelineScheduler):

@@ -676,21 +732,35 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

self._accum_loss = None
self._accum_moe_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(num_chunks)]
self._output_objs = [[] for _ in range(num_chunks)]
self._output_obj_grads = [[] for _ in range(num_chunks)]
self._moe_losses = [[] for _ in range(num_chunks)]

self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
self._output_obj_shapes = [None for _ in range(num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]

@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape

@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]

def _clear_state(self) -> None:
self._accum_loss = None
self._accum_moe_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(self._num_chunks)]
self._output_objs = [[] for _ in range(self._num_chunks)]
self._output_obj_grads = [[] for _ in range(self._num_chunks)]
self._moe_losses = [[] for _ in range(self._num_chunks)]

self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
self._output_obj_shapes = [None for _ in range(self._num_chunks)]

@@ -712,7 +782,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return move_to_device(micro_batch_data)

def _forward_step(self, engine, chunk_id):
def _forward_step(self, engine, chunk_id, moe_loss_coeff=1.0):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.

@@ -734,7 +804,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model[chunk_id], data)
output_obj, moe_losses = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)

@@ -754,7 +824,14 @@ class InterleavedPipelineScheduler(PipelineScheduler):
self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced

moe_loss = sum(moe_losses) * moe_loss_coeff
moe_loss /= self.num_microbatches

if self._accum_moe_loss is not None:
self._accum_moe_loss.add_(moe_loss.detach())

self._output_objs[chunk_id].append(output_obj)
self._moe_losses[chunk_id].append(moe_loss)

return output_obj


@@ -780,8 +857,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
input_obj = self._input_objs[chunk_id].pop(0)
output_obj = self._output_objs[chunk_id].pop(0)
output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
moe_loss = self._moe_losses[chunk_id].pop(0)

input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad, moe_loss)

return input_obj_grad


@@ -813,6 +891,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
receive_extra_backward: bool = False,
forward_only: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the warm-up loop and prepare data for the 1F1B stage.

@@ -850,12 +929,13 @@ class InterleavedPipelineScheduler(PipelineScheduler):
for k in range(num_warmup_microsteps):
chunk_id = self._get_chunk_by_microbatch(k)

output_obj = self._forward_step(engine, chunk_id)
output_obj = self._forward_step(engine, chunk_id, moe_loss_coeff)

if forward_only:
# when forward-only, no need to save tensors for a backward pass
self._input_objs[chunk_id].pop()
self._output_objs[chunk_id].pop()
self._moe_losses[chunk_id].pop()

if not gpc.is_pipeline_last_stage():
if isinstance(output_obj, torch.Tensor):

@@ -931,6 +1011,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the 1F1B loop with overlap.

@@ -960,7 +1041,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)

# 1. Forward pass.
output_obj = self._forward_step(engine, forward_chunk_id)
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)

# 2. Check if the backward input is ready.
if backward_async_communicator is not None:

@@ -1045,6 +1126,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
moe_loss_coeff: float = 1.0,
) -> None:
"""
Run the 1F1B loop without overlap.

@@ -1066,7 +1148,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# Forward pass.
forward_microstep_id = k + num_warmup_microsteps
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
output_obj = self._forward_step(engine, forward_chunk_id)
output_obj = self._forward_step(engine, forward_chunk_id, moe_loss_coeff)

# Backward pass.
backward_microstep_id = k

@@ -1171,7 +1253,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
)
)

def _forward_only_step(self, engine: Engine):
def _forward_only_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
num_microsteps = self.num_microbatches * self._num_chunks
num_warmup_microsteps = num_microsteps


@@ -1181,9 +1263,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_microsteps,
receive_extra_backward=False,
forward_only=True,
moe_loss_coeff=moe_loss_coeff,
)

def _forward_backward_step(self, engine: Engine):
def _forward_backward_step(self, engine: Engine, moe_loss_coeff: float = 1.0):
# Compute number of warmup and remaining microbatches.
all_warmup_microsteps = False
num_microsteps = self.num_microbatches * self._num_chunks

@@ -1217,6 +1300,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_microsteps,
num_warmup_steps,
receive_extra_backward=receive_extra_backward,
moe_loss_coeff=moe_loss_coeff,
)

# 2. 1F1B

@@ -1225,12 +1309,15 @@ class InterleavedPipelineScheduler(PipelineScheduler):
num_warmup_steps,
num_1f1b_micropairs=num_1f1b_micropairs,
all_warmup_microsteps=all_warmup_microsteps,
moe_loss_coeff=moe_loss_coeff,
)

# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)

def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
def forward_backward_step(
self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True, moe_loss_coeff=1.0
):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.


@@ -1250,24 +1337,36 @@ class InterleavedPipelineScheduler(PipelineScheduler):
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

gpc.set_virtual_pipeline_parallel_rank(0)

self.load_batch(engine, data_iter)

if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
self._accum_loss = torch.zeros(1, device=get_current_device())
self._accum_moe_loss = torch.zeros(1, device=get_current_device())

if return_output_label:
self._return_tensors = []

if forward_only:
self._forward_only_step(engine)
self._forward_only_step(engine, moe_loss_coeff)
else:
self._forward_backward_step(engine)
self._forward_backward_step(engine, moe_loss_coeff)

if return_output_label and len(self._return_tensors) > 0:
output, label = pack_return_tensors(self._return_tensors)
else:
output, label = (None, None)

logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")

dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
accum_moe_loss = self._accum_moe_loss

accum_loss = self._accum_loss
if accum_loss is not None:
accum_loss += self._accum_moe_loss

self._clear_state()

return output, label, accum_loss
return output, label, accum_loss, accum_moe_loss

@@ -155,5 +155,5 @@ class Trainer:
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
"""
output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
return output, label, loss
output, label, loss, moe_loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
return output, label, loss, moe_loss

@@ -108,67 +108,96 @@ def args_sanity_check():
logger.info(f"valid_every: {data.valid_every}")

# processing the checkpoint config
if "enable_save_ckpt" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("enable_save_ckpt", False)
ckpt = gpc.config.ckpt
if "enable_save_ckpt" not in ckpt:
ckpt._add_item("enable_save_ckpt", False)

if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0:
gpc.config.ckpt._add_item("checkpoint_every", float("inf"))
# Saving checkpoint args.
if ckpt.enable_save_ckpt:
assert "checkpoint_every" in ckpt, "If checkpoint saving is enabled, checkpoint_every must be given in config.ckpt!"
assert ckpt.checkpoint_every > 0
assert "save_ckpt_folder" in ckpt, "If checkpoint saving is enabled, save_ckpt_folder must be given in config.ckpt!"

if "load_optimizer" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_optimizer", True)
if "async_upload" not in ckpt:
ckpt._add_item("async_upload", False)  # async default is False.
else:
if ckpt.async_upload:
assert "save_ckpt_folder" in ckpt
if "boto3:" not in ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning(
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
)
ckpt.async_upload = False
else:
if "async_upload_tmp_folder" not in ckpt:
ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")

if "save_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("save_ckpt_folder", None)
if not ckpt.async_upload:
ckpt._add_item("async_upload_tmp_folder", None)

if "load_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_ckpt_folder", None)
if "snapshot_ckpt_folder" not in ckpt:
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))

if "load_model_only_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("load_model_only_folder", None)
if "oss_snapshot_freq" not in ckpt:
ckpt._add_item("oss_snapshot_freq", float("inf"))  # if oss_snapshot_freq is not given, snapshotting is disabled.
else:
ckpt._add_item("checkpoint_every", float("inf"))
ckpt._add_item("oss_snapshot_freq", float("inf"))
ckpt._add_item("save_ckpt_folder", None)
ckpt._add_item("async_upload", False)
ckpt._add_item("async_upload_tmp_folder", None)
ckpt._add_item("snapshot_ckpt_folder", None)

if "async_upload" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload", False)
# Loading checkpoint args.
if "load_model_only_folder" not in ckpt:
ckpt._add_item("load_model_only_folder", None)

if "async_upload_tmp_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
if "load_ckpt_folder" not in ckpt:
ckpt._add_item("load_ckpt_folder", None)

if gpc.config.ckpt.async_upload:
assert "save_ckpt_folder" in gpc.config.ckpt
if "boto3:" not in gpc.config.ckpt.save_ckpt_folder:
if gpc.is_rank_for_log():
logger.warning("Storing ckpt on file system does not support asynchronous storage, will use sync save!")
gpc.config.ckpt.async_upload = False
if "load_optimizer" not in ckpt:
ckpt._add_item("load_optimizer", True)

if "snapshot_ckpt_folder" not in gpc.config.ckpt:
gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot"))
if "stop_file_path" not in ckpt:
ckpt._add_item("stop_file_path", None)

if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"):
gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2)
assert gpc.config.ckpt.oss_snapshot_freq > 0
if "load_given_ckpt" not in ckpt:
# If 'load_given_ckpt' is not given, we set it to False, so internlm has the opportunity
# to auto-load the latest checkpoint.
ckpt._add_item("load_given_ckpt", False)

assert not (
gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None
), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time."
if ckpt.load_given_ckpt:
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
logger.warning(
"Detected 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
)
ckpt.load_model_only_folder = None

if gpc.is_rank_for_log():
logger.info("+" * 15 + " Ckpt Info " + "+" * 15)  # pylint: disable=W1201
logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
logger.info(f"async_upload: {gpc.config.ckpt.async_upload}")
if gpc.config.ckpt.async_upload:
logger.info(f"async_upload_tmp_folder: {gpc.config.ckpt.async_upload_tmp_folder}")
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")

# initialize the storage manager
init_storage_manager(gpc.config.ckpt)
init_storage_manager(ckpt)

# tensorboard writer config
if "enable_tb" not in gpc.config:
gpc.config._add_item("enable_tb", True)
if "tensorboard_folder" not in gpc.config:
gpc.config._add_item("tensorboard_folder", None)
gpc.config._add_item(
"tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
)
if "resume_tb_folder" not in gpc.config:
gpc.config._add_item("resume_tb_folder", None)
gpc.config._add_item(
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
)

# cudnn
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)

@@ -236,11 +265,13 @@ def args_sanity_check():
# process the model config
if "use_flash_attn" not in gpc.config.model:
gpc.config.model._add_item("use_flash_attn", True)
if "sequence_parallel" not in gpc.config.model:
gpc.config.model._add_item("sequence_parallel", False)

# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)
else:
assert not (
gpc.config.model.sequence_parallel is True and gpc.config.model.use_flash_attn is False
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
), "sequence parallel does not support use_flash_attn=False"

# feishu webhook address for alerting

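
# --- Illustration (not part of the commit): the checkpoint-source priority implied by the checks above. ---
# The sanity check's comment states the order: load_given_ckpt(True) > latest_checkpoint >
# load_model_only_folder. A sketch of that resolution; `find_latest_checkpoint()` is a
# hypothetical helper used only for illustration, the real lookup lives elsewhere in InternLM:
def resolve_ckpt_source(ckpt):
    if ckpt.load_given_ckpt and ckpt.load_ckpt_folder:
        return ("resume", ckpt.load_ckpt_folder)          # weights + optimizer/scheduler states
    latest = find_latest_checkpoint()                      # hypothetical helper
    if latest is not None:
        return ("resume", latest)
    if ckpt.load_model_only_folder:
        return ("weights_only", ckpt.load_model_only_folder)
    return ("from_scratch", None)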
@@ -7,6 +7,7 @@ import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn


@@ -56,7 +57,7 @@ class Embedding1D(nn.Module):

output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

if gpc.config.model.sequence_parallel:
if gpc.config.parallel.sequence_parallel:
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

return output

@@ -111,6 +112,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):

apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply


class RotaryEmbedding(torch.nn.Module):

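
# --- Illustration (not part of the commit): sequence_parallel now lives under the parallel config. ---
# Embedding1D, the parallel linear layers, and MHA now read gpc.config.parallel.sequence_parallel
# instead of gpc.config.model.sequence_parallel, so a config that enables the feature would set it
# roughly like this (a sketch, not taken from the repo):
parallel = dict(
    pipeline=dict(size=4, interleaved_overlap=False),
    sequence_parallel=True,  # read by the tensor-parallel modules after this change
)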
@@ -62,7 +62,7 @@ class ScaleColumnParallelLinear(nn.Linear):
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
)



@@ -111,7 +111,7 @@ class RewardModelLinear(ScaleColumnParallelLinear):
weight,
self.bias,
process_group=self.process_group,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
)



@@ -173,7 +173,7 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

@@ -182,7 +182,7 @@ class FeedForward(nn.Module):
hidden_features,
process_group,
bias,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

@@ -191,7 +191,7 @@ class FeedForward(nn.Module):
out_features,
process_group,
bias=bias,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

@@ -393,7 +393,7 @@ class PackedFlashInternLm1D(nn.Module):
max_position_embeddings=-1,
process_group=gpc.get_group(ParallelMode.TENSOR),
padding_idx=None,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
device=device,
dtype=dtype,
)

@@ -74,7 +74,7 @@ class MoE(torch.nn.Module):
drop_tokens: bool = True,
use_rts: bool = True,
using_default_moe: bool = True,
use_residual=True,
use_residual=False,
residual_mlp=None,
):

@@ -82,7 +82,7 @@ class MHA(nn.Module):
3 * embed_dim,
process_group,
bias=True,
sequence_parallel=gpc.config.model.sequence_parallel,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)  # according to https://spaces.ac.cn/archives/9577


@@ -95,7 +95,11 @@ class MHA(nn.Module):

# the output projection always has a bias (for now)
self.out_proj = RowParallelLinearTorch(
embed_dim, embed_dim, process_group, sequence_parallel=gpc.config.model.sequence_parallel, **factory_kwargs
embed_dim,
embed_dim,
process_group,
sequence_parallel=gpc.config.parallel.sequence_parallel,
**factory_kwargs,
)
# need to assign the tp attribute so that internlm knows it is a tensor parallel module
if gpc.get_world_size(ParallelMode.TENSOR) > 1:

@@ -356,6 +356,8 @@ class TopKGate(Module):
# Only top-1 and top-2 are supported at the moment.
if k not in (1, 2):
raise ValueError("Only top-1 and top-2 gatings are supported.")
# TODO: can we use tensor parallel here?
# Following DeepSpeed's mechanism, the gate always uses fp32
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.k = k
self.capacity_factor = capacity_factor

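
# --- Illustration (not part of the commit): the gate projection stays in fp32. ---
# Calling .float() keeps the gate's weight in torch.float32 even when the rest of the model runs
# in bf16/fp16, so the routing logits are computed in full precision. A standalone sketch (the
# explicit activation cast below is part of the sketch, not a claim about the repo's forward pass):
import torch

wg = torch.nn.Linear(8, 4, bias=False).float()        # fp32 gate weights
tokens = torch.randn(2, 8, dtype=torch.bfloat16)      # bf16 activations from the trunk
logits = wg(tokens.float())                           # cast activations up to match the fp32 gate
topk_scores, topk_experts = torch.topk(logits, k=2, dim=-1)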
@@ -3,6 +3,7 @@

import math
from functools import partial
from itertools import product

import torch
import torch.distributed as dist

@@ -20,6 +21,7 @@ from internlm.solver.optimizer.store import (
)
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
ParamBcastSyncHandler,
flatten,
get_grad_accumulate_object,
has_inf_or_nan,

@@ -88,10 +90,10 @@ class HybridZeroOptimizer(BaseOptimizer):
self,
optimizer: Optimizer,
cpu_offload=False,
overlap_broadcast=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
has_moe: bool = False,
param_bcast_sync_handler: ParamBcastSyncHandler = None,
):
# DynamicGradScaler related args
if gpc.config.model.dtype is torch.float32:

@@ -163,7 +165,9 @@ class HybridZeroOptimizer(BaseOptimizer):
+ f"zo-{self._zero_local_rank}.pt"
)
self.params_per_rank_id_dict = []
self.overlap_broadcast = overlap_broadcast
self._param_bcast_sync_handler = param_bcast_sync_handler
if self._overlap_communication:
assert self._param_bcast_sync_handler is not None

# iterate over the param group in the optimizer
# partition these param groups for data parallel training

@@ -238,6 +242,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()

# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached

@@ -284,8 +290,10 @@ class HybridZeroOptimizer(BaseOptimizer):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])

rank_to_go = numel_per_rank.index(min(numel_per_rank))
if self._overlap_communication:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()

@@ -322,7 +330,9 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store.add_accumulate_grad_object(accum_grad_obj)

reduction_func = partial(
self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
self._store_and_try_reduce_grads_by_bucket,
param=param,
reduce_rank=reduce_rank,
)

# define hook

@@ -416,16 +426,16 @@ class HybridZeroOptimizer(BaseOptimizer):

def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, dp_parallel_mode):
if self._overlap_communication:
stream = self._comm_stream
stream.synchronize()
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
else:
stream = torch.cuda.current_stream()

with torch.cuda.stream(stream):
with torch.cuda.stream(self._comm_stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=dp_parallel_mode
tensor=flat,
dtype=self.dtype,
dst_rank=reduce_rank,
parallel_mode=dp_parallel_mode,
)

# update the reduced tensor

@@ -616,7 +626,10 @@ class HybridZeroOptimizer(BaseOptimizer):
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None

@@ -678,37 +691,42 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)

# TODO: support broadcast overlap
self.broadcast_params(overlap=False)
with torch.cuda.stream(self._comm_stream):
self.broadcast_params()

timer("step").stop()

# update gradients may not be needed here, because the sync_params function is used in initialization,
# so synchronization is maintained
return True, [global_norm / loss_scale for global_norm in global_norm_groups]

def broadcast_params(self, overlap=False):
def broadcast_params(self):
handles = []

for group_id in range(self.num_param_groups):
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
if self._is_moe_group(self.optim.param_groups[group_id]):
continue
for rank in range(self._zero_world_size):
# The following operations are performed only on the rank to which parameters are assigned.
if rank not in self.param_group_no_params_ranks[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
)
handles.append(handle)
# The following operations are performed only on the rank to which parameters are assigned.
if rank in self.param_group_no_params_ranks[group_id]:
continue
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(ParallelMode.ZERO1),
async_op=True,
)

if not overlap:
for handle in handles:
handle.wait()
else:
return handles
if self._overlap_communication:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)

for handle in handles:
handle.wait()

##################
# FP16 Utilities #

@@ -726,7 +744,11 @@ class HybridZeroOptimizer(BaseOptimizer):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
dist.all_reduce(
self._found_overflow,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.GLOBAL),
)

return self._found_overflow.item() > 0

@@ -3,15 +3,18 @@

import math
from abc import ABC, abstractmethod
from typing import Dict, Optional
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed as dist
from torch import Tensor
from torch import Tensor, nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter

@@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor):


def split_half_float_double(tensor_list):
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for _, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
if bucket:
buckets.append(bucket)
dtype_buckets = {
"torch.cuda.HalfTensor": [],
"torch.cuda.FloatTensor": [],
"torch.cuda.DoubleTensor": [],
"torch.cuda.BFloat16Tensor": [],
}

for t in tensor_list:
dtype = t.type()
if dtype in dtype_buckets:
dtype_buckets[dtype].append(t)

buckets = [bucket for bucket in dtype_buckets.values() if bucket]
return buckets

|
|||
if APEX_AVAILABLE:
|
||||
dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
norm, _ = multi_tensor_applier(
|
||||
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
|
||||
amp_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False, # no per-parameter norm
|
||||
)
|
||||
else:
|
||||
norm, _ = multi_tensor_l2norm_torch(grads, False)
|
||||
|
|
@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
|
|||
|
||||
# Take max across all model-parallel GPUs.
|
||||
if gpc.get_world_size(ParallelMode.MODEL) > 1:
|
||||
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
|
||||
dist.all_reduce(
|
||||
total_norm_cuda,
|
||||
op=dist.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
)
|
||||
total_norm = total_norm_cuda[0].item()
|
||||
else:
|
||||
tensor_parallel_grads = []
|
||||
|
|
@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
|
|||
|
||||
# Sum across all model-parallel GPUs.
|
||||
if gpc.is_initialized(ParallelMode.MODEL):
|
||||
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
|
||||
dist.all_reduce(
|
||||
total_norm,
|
||||
op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.MODEL),
|
||||
)
|
||||
|
||||
# This is because we use zero1, so we need to use this reduction.
|
||||
# TODO: Check zero group to be a subset of dp group.
|
||||
|
|
@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
|
|||
self._scale = self._scale.fill_(state_dict["_scale"])
|
||||
        self._growth_step = state_dict["_growth_step"]
        self._hysteresis_step = state_dict["_hysteresis_step"]


class ParamBcastSyncHandler:
    """
    Model Partition Handler for overlap broadcast with forward
    """

    def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
        self._block_to_param = OrderedDict()  # <key: nn.Module> <value: list(param)>
        self._param_to_rank = dict()  # <key: param> <value: rank>
        self._block_to_rank = dict()  # <key: nn.Module> <value: rank>
        self._bcast_handles = dict()  # <key: rank> <value: list(bcast handles)>

        zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
        total_param_num = sum(p.numel() for p in model.parameters())
        avg_param_num = total_param_num * 1.0 // zero1_size

        # just want to share the same for loop for ModuleList and Module
        if not isinstance(model, nn.ModuleList):
            model = [model]

        # record the parameters of each transformer/embedding/head/norm block
        for _chunk in model:
            if isinstance(_chunk, NaiveAMPModel):
                _chunk = _chunk.model

            for _, children in _chunk.named_children():
                # should be the transformer block definition in modeling_xxx.py
                if isinstance(children, nn.ModuleList):
                    # record the block that a parameter belongs to
                    for _, block in enumerate(children):
                        # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                        self._block_to_param[block] = list(block.parameters())
                else:
                    # record the block that a parameter belongs to
                    # self._block_to_param[name] = list(children.parameters())
                    self._block_to_param[children] = list(children.parameters())

        alloc_num = 0
        rank_to_go = 0

        # process the parameters in block_to_param sequentially,
        # allocate each parameter to a local rank of ParallelMode.ZERO1,
        # NOTE that we do NOT consider the following scenarios:
        # 1) whether a parameter is trainable;
        # 2) parameters may be in different optimizer groups
        for block, params in self._block_to_param.items():
            # allocate a model block to a local rank of ParallelMode.ZERO1
            self._block_to_rank[block] = [rank_to_go]
            for p in params:
                alloc_num = alloc_num + p.numel()
                # in this case, allocate the param to the next rank if possible
                if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
                    rank_to_go = rank_to_go + 1
                    alloc_num = 0
                    self._block_to_rank[block].append(rank_to_go)
                # allocate a parameter to a local rank of ParallelMode.ZERO1
                self._param_to_rank[p] = rank_to_go

        # initialize an empty list for _bcast_handles of each rank
        for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
            self._bcast_handles[rank] = []

        # register_forward_pre_hook for transformer/embedding/norm/xxx block
        self._register_sync_parameters_hook()

    def _register_sync_parameters_hook(self) -> None:
        def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
            bcast_handles = []
            # gather all required broadcast handles into a list
            for rank in self._block_to_rank[model]:
                bcast_handles.extend(self._bcast_handles[rank])
                # need to clear _bcast_handles since they would be processed later
                self._bcast_handles[rank] = []
            # wait for all required broadcast handles to be completed
            for handle in bcast_handles:
                handle.wait()

        # register_forward_pre_hook for transformer/embedding/norm/xxx block
        for block, _ in self._block_to_rank.items():
            block.register_forward_pre_hook(partial(_pre_forward_hook))

    def get_rank_by_param(self, param) -> int:
        return self._param_to_rank[param]

    def add_bcast_handle(self, rank, handle) -> None:
        self._bcast_handles[rank].append(handle)

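# A minimal usage sketch of how ParamBcastSyncHandler is meant to be driven (illustrative, not
# taken from this diff; the src/group arguments are assumptions and a real caller would map the
# local ZERO1 rank to its global rank): the ZeRO optimizer broadcasts updated parameters
# asynchronously after its step, registers the async work objects with the handler, and the
# forward_pre_hooks installed above wait on them just before the owning block runs, so the
# broadcast overlaps with the next forward pass.
handler = ParamBcastSyncHandler(model)
for param in model.parameters():
    owner = handler.get_rank_by_param(param)  # local ZERO1 rank that updates this param
    work = dist.broadcast(param, src=owner, group=gpc.get_group(ParallelMode.ZERO1), async_op=True)
    handler.add_bcast_handle(owner, work)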
@ -0,0 +1,21 @@
from .training_internlm import (
    get_train_data_loader,
    get_validation_data_loader,
    initialize_distributed_env,
    initialize_llm_profile,
    initialize_model,
    initialize_optimizer,
    load_new_batch,
    record_current_batch_training_metrics,
)

__all__ = [
    "get_train_data_loader",
    "get_validation_data_loader",
    "initialize_distributed_env",
    "initialize_llm_profile",
    "initialize_model",
    "initialize_optimizer",
    "load_new_batch",
    "record_current_batch_training_metrics",
]

@ -0,0 +1,447 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import time
from functools import partial
from typing import Callable, Iterable, Union

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
from internlm.data.dataset import get_dataset_dict
from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import (
    PackedDataset,
    PackedDatasetWithoutCuSeqlen,
    get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.moe import create_moe_param_groups, has_moe_layers
from internlm.monitor import set_env_var
from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.utils.common import DummyProfile, get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import (
    is_no_pp_or_last_stage,
    sync_model_param_with_ep,
    sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER

logger = get_logger(__file__)


def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
    """
    Initialize distributed environment for distributed training.

    Args:
        config (str): Config file path.
        launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
        master_port (int): The master port for distributed training. 8888 by default.
        seed (int, optional): Specified random seed for every process. 1024 by default.
    """

    torch.cuda.empty_cache()

    if launcher == "torch":
        internlm.launch_from_torch(config=config, seed=seed)
    elif launcher == "slurm":
        internlm.launch_from_slurm(
            config=config,
            host=get_master_node(),
            port=master_port,
            seed=seed,
        )
    else:
        assert launcher in ["slurm", "torch"], "launcher only supports slurm or torch"


def initialize_model():
    """
    Initialize model.

    Returns: The neural network model to be trained or evaluated.
    """

    model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
    if isinstance(model, nn.ModuleList):
        model = nn.ModuleList(
            [
                NaiveAMPModel(
                    model=_m,
                    output_to_fp32=False,  # manually controlled by the interleaved pipeline scheduler
                    dtype=gpc.config.model.get("dtype", torch.half),
                    sync_buffer=False,
                )
                for _m in model
            ]
        )
    else:
        model = NaiveAMPModel(
            model=model,
            output_to_fp32=is_no_pp_or_last_stage(),
            dtype=gpc.config.model.get("dtype", torch.half),
            sync_buffer=False,
        )

    # This sync is very important, because the model weights kept in the optimizer are copied
    # from the original parameters in memory, so we must make sure the dp sync does not leave
    # the optimizer's copies different from the original parameters.
    sync_model_param_with_ep(model)

    # This function is needed to make sure parameters that are not split by tensor parallelism are
    # the same across tensor parallelism.
    sync_model_param_within_tp(model)

    return model


def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
    """
    Initialize optimizer.

    Args:
        model (torch.nn.Module): Your model instance to be trained or evaluated.

    Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
    """
    param_bcast_sync_handler = ParamBcastSyncHandler(model)
    adam_cfg = gpc.config.adam
    if gpc.config.model.num_experts > 1:
        params = create_moe_param_groups(model, adam_cfg.weight_decay)
    else:
        params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
    naive_optimizer = torch.optim.AdamW(
        params=params,
        lr=adam_cfg.lr,
        betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
        eps=adam_cfg.adam_eps,
    )

    has_moe = has_moe_layers(model)
    optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=gpc.config.grad_scaler,
        zero_cfg=gpc.config.hybrid_zero_optimizer,
        has_moe=has_moe,
        param_bcast_sync_handler=param_bcast_sync_handler,
    )

    beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

    lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)

    return optimizer, beta2_scheduler, lr_scheduler

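# For reference, the `adam` section of the training config consumed by initialize_optimizer above
# would look roughly like the sketch below; the field names mirror the attribute accesses
# (adam_cfg.lr, adam_cfg.adam_beta1, ...), while the numeric values are placeholders, not values
# taken from this diff.
adam = dict(
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_eps=1e-8,
    weight_decay=0.01,
)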
def get_train_data_loader(
    num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
    """
    Generate and return the training data loader.

    Returns: A tuple of (train_dl, dataset_types).
    """

    # Get the dataset types
    dataset_types = None
    dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
    data_cfg = gpc.config.data

    # Get the sample weight dictionary
    train_folder = data_cfg.train_folder

    if not train_folder:
        train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
        if data_cfg.pack_sample_into_one:
            train_ds = PackedDatasetWithoutCuSeqlen(
                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
            )
        else:
            train_ds = PackedDataset(
                train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
            )
    else:
        if dataset_generate_func is not None:
            train_ds = dataset_generate_func()
        else:
            train_ds = get_packed_dataset_without_short_length(
                folder=data_cfg.train_folder,
                packed_length=data_cfg.packed_length,
                max_length_per_sample=data_cfg.seq_len,
                show_progress=dist.get_rank() == 0,
                min_length=data_cfg.min_length,
                min_length_dict=data_cfg.get("min_length_dict", {}),
                pack_into_one_sample=data_cfg.pack_sample_into_one,
            )

    if dataset_generate_func is None or not train_folder:
        # partition already completed
        assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen, ConcatDataset))
        # Create the training dataset sampler
        train_sampler = StaticBatchSampler(
            train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
            batch_size=data_cfg.micro_num,
            rampup_batch_size=data_cfg.rampup_batch_size,
            micro_bsz=data_cfg.micro_bsz,
            seed=1024,
            drop_last=True,
            data_rank=gpc.get_local_rank(ParallelMode.DATA),
            data_world_size=gpc.get_world_size(ParallelMode.DATA),
        )

    if dataset_generate_func is None or not train_folder:
        train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)

    # Create the training data loader
    train_dl = DataLoader(
        dataset=train_ds,
        batch_sampler=train_sampler,
        num_workers=num_worker,
        pin_memory=True,
        collate_fn=train_collate_fn,
        persistent_workers=num_worker > 0,
    )

    return train_dl, dataset_types


def get_validation_data_loader(
    num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
):
    """Generate and return the validation data loader."""

    data_cfg = gpc.config.data

    if not data_cfg.valid_folder:
        val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
    else:
        if dataset_generate_func is not None:
            assert val_collate_fn and dataloader_func is not None
            val_ds = dataset_generate_func()
        else:
            val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")

    if not isinstance(val_ds, dict):
        val_ds = {"val": val_ds}

    if val_collate_fn is None or not data_cfg.valid_folder:
        val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)

    val_dls = {}
    for val_name, ds in val_ds.items():
        if dataloader_func and data_cfg.valid_folder is not None:
            val_dls[val_name] = dataloader_func(dataset=ds, collate_fn=val_collate_fn)
            if gpc.is_rank_for_log():
                logger.info(
                    f"load validation dataset {val_name} with valid batch size {str(data_cfg.valid_micro_num)} and "
                    f"{ds.size} Byte samples."
                )
        else:
            # Making the validation batch size larger can speed up the evaluation, but it should not be too large,
            # otherwise too much data may be dropped.
            batch_size = min(
                data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
            )
            batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz

            if batch_size == 0 and gpc.is_rank_for_log():
                logger.info(f"skip validate {val_name}.")
                continue

            val_dls[val_name] = get_dpsampler_dataloader(
                ds,
                shuffle=False,
                num_workers=num_worker,
                batch_size=batch_size,
                collate_fn=val_collate_fn,
                drop_last=True,
            )  # drop_last=True, otherwise it may cause problems in the last batch

            if gpc.is_rank_for_log():
                logger.info(
                    f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
                    f"samples {str(len(val_dls[val_name]))}."
                )

    return val_dls


def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
    """
    Load and return the new batch data based on the training data loader.

    Args:
        train_dl (torch.utils.data.DataLoader): Dataloader for training.
        train_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
        train_state (TrainState): Current training state.

    Returns: A batch of data and the updated train_iter.
    """

    timer("batch-gen").start()
    try:
        batch = next(train_iter)  # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
        if hasattr(train_state, "batch_sampler_iter"):
            next(train_state.batch_sampler_iter)
    except StopIteration:
        train_iter = iter(train_dl)
        batch = next(train_iter)
        train_state.num_consumed_samples_in_epoch = 0
        if hasattr(train_state, "batch_sampler"):
            train_state.batch_sampler_iter = iter(train_state.batch_sampler)
            next(train_state.batch_sampler_iter)
    timer("batch-gen").stop()

    if batch[0].get("type_ids", None) is not None:
        # if use_flash_attn is False, we need to unpack type_ids
        if not gpc.config.model.use_flash_attn:
            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])

    return batch, train_iter

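# A condensed sketch of how the training loop is expected to consume load_new_batch (the loop
# bound and the step body are illustrative, not taken from this diff): the helper transparently
# restarts the iterator at epoch boundaries and keeps the train_state sampler iterator in sync.
train_iter = iter(train_dl)
while train_state.step_count < gpc.config.data.total_steps:
    batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
    # ... forward/backward, optimizer step and metric recording happen here ...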
def initialize_llm_profile(profiling: bool = False, start_time: str = None):
    """Initialize and return the profiler context manager instance."""

    if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        llm_profile = torch.profiler.profile
        logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
    else:
        llm_profile = DummyProfile

    return llm_profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
            + f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
        ),
        with_stack=True,
        with_modules=True,
    )

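# The object returned by initialize_llm_profile is used as an ordinary context manager, which is
# why DummyProfile mirrors the torch.profiler interface; a hedged usage sketch (current_time and
# total_steps are assumed to be defined by the caller, and the loop body is illustrative):
with initialize_llm_profile(profiling=True, start_time=current_time) as prof:
    for _ in range(total_steps):
        # ... one training step ...
        prof.step()  # advances the profiling schedule; a no-op for DummyProfile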
def record_current_batch_training_metrics(
    get_tflops_func,
    logger,
    writer,
    success_update,
    batch_count,
    batch,
    train_state,
    optimizer,
    beta2_scheduler,
    trainer,
    start_time,
    loss,
    moe_loss,
    grad_norm,
    metric,
    update_panel,
):
    """
    Print some training metrics of the current batch.
    """

    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))

    if success_update in (0, True):
        train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
    if is_no_pp_or_last_stage():
        acc_perplex = metric.get_metric()

    if success_update and gpc.is_rank_for_log():
        lr = optimizer.param_groups[0]["lr"]
        if hasattr(trainer.engine.optimizer, "grad_scaler"):
            scaler = trainer.engine.optimizer.grad_scaler._scale.item()
        elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
            scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()

        num_tokens_in_batch = batch[1].nelement()
        num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
        max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
        max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
        min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])

        tk_per_gpu = 0
        tk_per_gpu = round(
            num_tokens_in_batch
            * gpc.get_world_size(ParallelMode.DATA)
            / gpc.get_world_size(ParallelMode.GLOBAL)
            / (time.time() - start_time),
            2,
        )

        tflops = get_tflops_func((time.time() - start_time))

        infos = {
            "tflops": tflops,
            "step": batch_count,
            "loss": loss.item(),
            "moe_loss": moe_loss.item(),
            "tgs (tokens/gpu/second)": tk_per_gpu,
            "lr": lr,
            "loss_scale": scaler,
            "grad_norm": grad_norm,
        }

        infos["micro_num"] = len(batch[1])
        infos["num_consumed_tokens"] = train_state.num_consumed_tokens
        infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
        infos["num_samples_in_batch"] = num_samples_in_batch  # the number of samples in the current batch
        infos["largest_length"] = max_length_in_batch  # the longest input in the batch
        infos["largest_batch"] = max_samples_in_batch  # the micro batch with the most samples
        infos["smallest_batch"] = min_samples_in_batch
        infos["adam_beta2"] = beta2_scheduler.get_beta2()

        fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
        infos["fwd_bwd_time"] = fwd_bwd_time

        for key, value in acc_perplex.items():
            infos[key] = value

        line = ""
        for key, value in infos.items():
            line += f"{key}={value} "
            writer.add_scalar(key=key, value=value, step=train_state.step_count)

        if update_panel:
            logger.info(
                line,
                extra={
                    "step": batch_count,
                    "lr": lr,
                    "num_consumed_tokens": train_state.num_consumed_tokens,
                    "grad_norm": grad_norm,
                    "loss": loss.item(),
                    "moe_loss": moe_loss.item(),
                    "flops": tflops,
                    "tgs": tk_per_gpu,
                    "acc": acc_perplex["acc"],
                    "perplexity": acc_perplex["perplexity"],
                    "fwd_bwd_time": fwd_bwd_time,
                },
            )
        else:
            logger.info(line)

        # if a loss spike occurs, send alert info to feishu
        mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())

@ -218,3 +218,21 @@ def get_megatron_flops(
    tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
    return tflops


class DummyProfile:
    """
    Dummy Profile.
    """

    def __init__(self, *args, **kwargs) -> None:
        pass

    def __enter__(self):
        return self

    def __exit__(self, a, b, c):
        pass

    def step(self):
        pass

@ -50,6 +50,16 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
    trainer.schedule._hooks = prev_metric_hooks


@contextmanager
def switch_sequence_parallel_mode():
    prev_mode = gpc.config.parallel.sequence_parallel
    try:
        gpc.config.parallel.sequence_parallel = False
        yield
    finally:
        gpc.config.parallel.sequence_parallel = prev_mode


def evaluate_on_val_dls(
    trainer,
    val_dls,
@ -57,110 +67,102 @@ def evaluate_on_val_dls(
    logger,
    step_count,
    update_panel: bool = False,
    streaming: bool = False,
):
    torch.cuda.empty_cache()
    trainer.eval()
    verbose = gpc.is_rank_for_log()
    data_cfg = gpc.config.data
    with switch_sequence_parallel_mode():
        torch.cuda.empty_cache()
        trainer.eval()
        verbose = gpc.is_rank_for_log()
        data_cfg = gpc.config.data

        for val_name, val_dl in val_dls.items():
            if len(val_dl) == 0 and verbose:
                logger.info(f"Validation dataset: {val_name} is empty")
                continue
        for val_name, val_dl in val_dls.items():
            if len(val_dl) == 0 and verbose and not streaming:
                logger.info(f"Validation dataset: {val_name} is empty")
                continue

            val_metric = AccPerplex(
                device=torch.cuda.current_device(),
                tp_pg=gpc.get_group(ParallelMode.TENSOR),
                dp_pg=gpc.get_group(ParallelMode.DATA),
            )
            val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
            val_metric = AccPerplex(
                device=torch.cuda.current_device(),
                tp_pg=gpc.get_group(ParallelMode.TENSOR),
                dp_pg=gpc.get_group(ParallelMode.DATA),
            )
            val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)

            val_loss = 0
            val_idx = -1
            for val_idx, batch in tqdm(
                enumerate(val_dl),
                desc="Val.",
                total=len(val_dl),
                position=1,
                disable=not verbose,
                leave=False,
            ):
                with torch.inference_mode():
                    if gpc.is_using_pp():
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        num_microbatches = total_val_bsz // data_cfg.micro_bsz
                        tensor_shape = torch.Size(
                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
            val_loss = 0
            val_idx = -1
            for val_idx, batch in tqdm(
                enumerate(val_dl),
                desc="Val.",
                total=len(val_dl) if not streaming else None,
                position=1,
                disable=not verbose,
                leave=False,
            ):
                with torch.inference_mode():
                    if gpc.is_using_pp():
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        num_microbatches = total_val_bsz // data_cfg.micro_bsz
                        tensor_shape = torch.Size(
                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
                        )

                        with switch_evaluation_pipeline_scheduler(
                            trainer=trainer,
                            num_microbatches=num_microbatches,
                            tensor_shape=tensor_shape,
                            metric_hook_list=[val_sche_metric_hook],
                        ):
                            _, _, loss, _ = trainer.execute_schedule(
                                batch, forward_only=True, return_loss=True, return_output_label=False
                            )
                    else:
                        total_val_bsz = len(batch[1])
                        assert total_val_bsz % data_cfg.micro_bsz == 0
                        grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                        grad_accum_batch_size = data_cfg.micro_bsz
                        with switch_evaluation_no_pipeline_scheduler(
                            trainer=trainer,
                            grad_accum_size=grad_accum_size,
                            grad_accum_batch_size=grad_accum_batch_size,
                            metric_hook_list=[val_sche_metric_hook],
                        ):
                            _, _, loss, _ = trainer.execute_schedule(
                                batch, forward_only=True, return_loss=True, return_output_label=False
                            )
                if verbose:
                    val_loss += loss.item()

            assert val_idx != -1
            dist.barrier()

            val_res = val_metric.get_metric()
            if verbose and len(val_dl) != 0:
                val_loss = val_loss / (val_idx + 1 + 1e-6)
                infos = {
                    "step": step_count,
                    f"val/{val_name}_loss": val_loss,
                    f"val/{val_name}_acc": val_res["acc"],
                    f"val/{val_name}_plex": val_res["perplexity"],
                }

                for key, value in infos.items():
                    writer.add_scalar(key=key, value=value, step=step_count)

                if update_panel:
                    logger.info(
                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
                        extra={
                            "step": step_count,
                            "val_loss": val_loss,
                            "val_acc": val_res["acc"],
                            "val_perplexity": val_res["perplexity"],
                        },
                    )
                else:
                    logger.info(
                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
                    )

                    with switch_evaluation_pipeline_scheduler(
                        trainer=trainer,
                        num_microbatches=num_microbatches,
                        tensor_shape=tensor_shape,
                        metric_hook_list=[val_sche_metric_hook],
                    ):
                        _, _, loss = trainer.execute_schedule(
                            batch, forward_only=True, return_loss=True, return_output_label=False
                        )
                else:
                    total_val_bsz = len(batch[1])
                    assert total_val_bsz % data_cfg.micro_bsz == 0
                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
                    grad_accum_batch_size = data_cfg.micro_bsz
                    with switch_evaluation_no_pipeline_scheduler(
                        trainer=trainer,
                        grad_accum_size=grad_accum_size,
                        grad_accum_batch_size=grad_accum_batch_size,
                        metric_hook_list=[val_sche_metric_hook],
                    ):
                        _, _, loss = trainer.execute_schedule(
                            batch, forward_only=True, return_loss=True, return_output_label=False
                        )
            if verbose:
                val_loss += loss.item()

        assert val_idx != -1
        trainer.train()
        torch.cuda.empty_cache()
        dist.barrier()

        val_res = val_metric.get_metric()
        if verbose and len(val_dl) != 0:
            val_loss = val_loss / (val_idx + 1 + 1e-6)
            infos = {
                "step": step_count,
                f"val/{val_name}_loss": val_loss,
                f"val/{val_name}_acc": val_res["acc"],
                f"val/{val_name}_plex": val_res["perplexity"],
            }

            for key, value in infos.items():
                writer.add_scalar(key=key, value=value, step=step_count)

            if update_panel:
                logger.info(
                    f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
                    extra={
                        "step": step_count,
                        "val_loss": val_loss,
                        "val_acc": val_res["acc"],
                        "val_perplexity": val_res["perplexity"],
                    },
                )
            else:
                logger.info(
                    f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
                )

    trainer.train()
    torch.cuda.empty_cache()
    dist.barrier()


@contextmanager
def switch_sequence_parallel_mode():
    prev_mode = gpc.config.model.sequence_parallel
    try:
        gpc.config.model.sequence_parallel = False
        yield
    finally:
        gpc.config.model.sequence_parallel = prev_mode

@ -14,18 +14,19 @@ class _Timer:
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()
        self.stream = torch.cuda.current_stream()

    def start(self):
        """Start the timer."""
        assert not self.started_, "timer has already been started"
        torch.cuda.synchronize()
        self.stream.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, "timer is not started"
        torch.cuda.synchronize()
        self.stream.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

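# The change above only swaps the global torch.cuda.synchronize() for a synchronize on the stream
# captured at construction; the timer is still driven through the same start/stop/elapsed pattern
# used elsewhere in this diff, e.g.:
timer("fwd-bwd").start()
# ... forward and backward passes ...
timer("fwd-bwd").stop()
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)  # as consumed in record_current_batch_training_metrics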
@ -2,8 +2,12 @@
# -*- encoding: utf-8 -*-

import copy
import fcntl
import os
import re
import socket
import time
from collections import defaultdict
from enum import Enum
from typing import Dict

@ -12,6 +16,8 @@ import torch
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.model.moe import MoE
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
@ -25,8 +31,6 @@ from internlm.utils.storage_manager import (
logger = get_logger(__file__)

quit_signal_handler = None


class CheckpointType(Enum):
    NORMAL_CHECKPOINT = 1
@ -69,6 +73,8 @@ def save_model_checkpoint(folder, model):
    """

    states = model.state_dict()
    # get non-moe parameters
    states = get_non_moe_state_dict(states)
    topo = get_model_topology(model)

    if folder is not None:
@ -92,6 +98,9 @@ def save_model_checkpoint(folder, model):
        topo_fp = os.path.join(folder, topo_fn)
        llm_save(topo_fp, saved_obj=topo)

    # move the judgement logic into save_moe_checkpoint(.)
    try_save_moe_checkpoint(folder, model)

    torch.distributed.barrier()

@ -128,6 +137,18 @@ def load_model_checkpoint(folder, model):
    fp = os.path.join(folder, should_load_name)
    states = llm_load(fp, map_location=get_current_device())

    """
    # need to convert the gate parameters to float32 (to fit the deepspeed-style mechanism); it may cause
    # round-off in gate.weight. The conversion is also done during forward, so we can comment it out here;
    # this keeps the gate parameters in float16 before forward.
    for key in list(states.keys()):
        if 'moe_layer.gate.wg.weight' in key:
            states[key] = states[key].float()
            print("load: ", states[key].float(), flush=True)
    """

    try_load_moe_checkpoint(folder, model, states)

    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
    if len(missing_k) != 0:
        logger.warning(f"Warning: missing keys {missing_k}")
@ -139,6 +160,58 @@ def load_model_checkpoint(folder, model):
    torch.cuda.empty_cache()


def try_save_moe_checkpoint(folder, model):
    # Using layer_#_expert_# to save the model's expert state_dict, a hack.
    moe_layer_id = 0
    for n_module, module in model.named_modules():
        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
            num_local_experts = module.num_local_experts
            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)

            # get all moe parameters
            moe_state_dict = {}
            for n, p in module.state_dict().items():
                if "expert" in n and "moe_layer.gate.wg.weight" not in n:
                    moe_state_dict[n_module + "." + n] = p
            moe_str_prefix = ".moe_layer.experts.experts."
            # Reorder the moe name rank, so that each checkpoint only has one expert
            experts_state_dict = defaultdict(dict)
            for key in list(moe_state_dict.keys()):
                m = re.match(f".*{moe_str_prefix}([0-9]+).*", key)

                local_expert_id = None
                if not m:
                    logger.warning(f"No expert found in key {key}.")
                else:
                    local_expert_id = m.group(1)

                global_expert_id = expp_rank * num_local_experts + int(local_expert_id)
                expert_key = key.replace(f"{moe_str_prefix}{local_expert_id}", f"{moe_str_prefix}{global_expert_id}")

                # truncating extra tensor (shared) storage
                truncated = moe_state_dict.pop(key).clone().detach()
                experts_state_dict[str(global_expert_id)][expert_key] = truncated

            # let's save the moe parameters
            for global_expert_id, expert_state_dict in experts_state_dict.items():
                # save the moe parameters
                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
                fp = os.path.join(folder, fn)
                llm_save(fp, saved_obj=expert_state_dict)
            moe_layer_id += 1


def get_non_moe_state_dict(full_state_dict):
    """
    Get the state dict of the non-moe layers
    """
    for key in list(full_state_dict.keys()):
        if "expert" in key and "moe_layer.gate.wg.weight" not in key:
            full_state_dict.pop(key)

    return full_state_dict


def save_optimizer_checkpoint(optim, state_path):
    """Store the state of the optimizer to the local file system or remote OSS.

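# A concrete illustration of the expert-id / filename mapping used by try_save_moe_checkpoint
# (the numbers are made up): with 2 local experts per rank and expert-parallel rank 3, local
# expert 1 of MoE layer 0 becomes global expert 3 * 2 + 1 = 7.
num_local_experts, expp_rank, moe_layer_id, local_expert_id = 2, 3, 0, 1
global_expert_id = expp_rank * num_local_experts + local_expert_id  # -> 7
fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"  # -> "model_moe_layer0_expert7.pt"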
@ -167,42 +240,25 @@ def save_optimizer_checkpoint(optim, state_path):
    llm_save(os.path.join(state_path, fp), states)


def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
    """
    Save checkpoint to the given folder path.
    """

    start = time.time()
    torch.distributed.barrier()
    folder = os.path.join(folder, str(train_state.step_count))
    logger.info(
        f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..."
    )

    timer("save-model").start()
    save_model_checkpoint(folder=folder, model=model)
    timer("save-model").stop()

    timer("save-optimizer").start()
    save_optimizer_checkpoint(optim=optimizer, state_path=folder)
    timer("save-optimizer").stop()

    if gpc.is_rank_for_log():
        scheduler_states = scheduler.state_dict()
        llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)

        sampler_state = train_state.batch_sampler.state_dict()
        llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
        llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())

        if model_config is not None:
            llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)

    torch.distributed.barrier()

    if gpc.is_rank_for_log():
        timer.log(["save-model", "save-optimizer"], logger=logger)
        logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")


def try_load_moe_checkpoint(folder, model, state_dict):
    moe_layer_id = 0
    for _, module in model.named_modules():
        if isinstance(module, MoE):  # and deepspeed.comm.get_rank() == 0:
            num_local_experts = module.num_local_experts
            expp_rank = gpc.get_local_rank(ParallelMode.EXPERT)
            # loop all local_experts
            for local_expert_id in range(num_local_experts):
                global_expert_id = expp_rank * num_local_experts + local_expert_id
                fn = f"model_moe_layer{moe_layer_id}_expert{global_expert_id}.pt"
                fp = os.path.join(folder, fn)
                expert_state_dict = llm_load(fp, map_location=get_current_device())
                # Updating global -> local expert ids
                moe_str_prefix = ".moe_layer.experts.experts."
                for key in list(expert_state_dict.keys()):
                    local_key = key.replace(f"{moe_str_prefix}{global_expert_id}", f"{moe_str_prefix}{local_expert_id}")
                    expert_state_dict[local_key] = expert_state_dict.pop(key)
                state_dict.update(expert_state_dict)
            moe_layer_id += 1


def load_optimizer_checkpoint(folder, optim):

@ -304,19 +360,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
    logger.info(f"reload load_scheduler:{lr_scheduler}")


class CheckpointSaveManager:
class CheckpointManager:
    """StorageManagerContext"""

    def __init__(
        self,
        ckpt_config,
        model,
        optimizer,
        lr_scheduler,
        model_config,
    ) -> None:
    def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
        """
        CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous
        CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
        upload mode, you must call wait_async_upload_finish at the end of the program to wait
        for the asynchronous ckpt upload to complete.

@ -332,26 +381,95 @@ class CheckpointSaveManager:
        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
        self.stop_file_path = ckpt_config.stop_file_path
        self.load_model_only_folder = ckpt_config.load_model_only_folder
        self.feishu_address = feishu_address
        self.storage_manager = get_storage_manager()
        self.snapshot_counter = 0
        self.load_optimizer = gpc.config.ckpt.load_optimizer

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.model_config = model_config

        if self.stop_file_path and gpc.get_global_rank() == 0:
            dir_path = os.path.dirname(self.stop_file_path)
            if dir_path != "" and not os.path.exists(dir_path):
                os.makedirs(dir_path)
            with open(self.stop_file_path, "w", encoding="utf-8") as f:
                f.write("0")

        if ckpt_config.load_given_ckpt is False:
            # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
            latest_ckpt_path = self.query_lastest_ckpt()
            if latest_ckpt_path:
                self.load_ckpt_folder = latest_ckpt_path
            else:
                # At this time, we have to load model init weights and train from step 0.
                self.load_ckpt_folder = self.load_model_only_folder
        else:
            self.load_ckpt_folder = ckpt_config.load_ckpt_folder

        if gpc.is_rank_for_log():
            logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
            if self.stop_file_path is None:
                logger.warning("no set stop_file_path, quit_signal_handler is disable")

    def quit_signal_handler(self, train_state) -> bool:
        """
        Exit signal detection function: if we write the exit step into the 'QUIT_FILE_PATH' file,
        all ranks will save a ckpt and exit.
        A negative integer step means save ckpt.
        A positive integer step means save ckpt and quit.

        Args:
            train_state (TrainState):
        Returns:
            bool: whether to quit.
        """
        now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT

        if self.stop_file_path is None:
            return now_break, now_save_ckpt, save_type

        with open(self.stop_file_path, "a+", encoding="utf-8") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.seek(0)
            msg = f.read()
            fcntl.flock(f, fcntl.LOCK_UN)
            action_step = int(msg)

        if action_step < 0 and abs(action_step) == train_state.step_count:
            now_save_ckpt = True

        if action_step > 0 and action_step == train_state.step_count:
            now_break, now_save_ckpt = True, True

        if action_step != 0 and gpc.is_rank_for_log():
            msg = "Stop" if action_step > 0 else "Save"
            action_step = abs(action_step)
            if train_state.step_count <= action_step:
                if self.feishu_address:
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"training will {msg} at step_count {action_step}!\
now step_count is {train_state.step_count}",
                    )

        return now_break, now_save_ckpt, save_type
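# For example, assuming stop_file_path points at a file visible to every rank, writing a positive
# step number makes all ranks save a checkpoint and exit at that step, while a negative number
# only triggers a save; a minimal sketch of the controlling side:
with open(stop_file_path, "w", encoding="utf-8") as f:
    f.write("1000")  # save ckpt and quit at step 1000; "-1000" would save without quitting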
    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return
            return False

        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
        if train_state.step_count % self.checkpoint_every == 0:
            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
        if save_ckpts is False:
            if quit_signal_handler is not None:
                save_ckpts, save_type = quit_signal_handler(train_state)
            save_ckpts = singal_save_ckpts
            save_type = singal_save_type

        if save_ckpts:
            # Wait for the previous round of asynchronous upload storage to complete.
@ -361,9 +479,9 @@ class CheckpointSaveManager:
            self.snapshot_counter = (self.snapshot_counter + 1) % 2
            save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
        else:
            save_ckpt_folder = self.save_ckpt_folder
            save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count))

        save_checkpoint(
        self.save_checkpoint(
            folder=save_ckpt_folder,
            model=self.model,
            optimizer=self.optimizer,
@ -372,7 +490,221 @@ class CheckpointSaveManager:
            model_config=self.model_config,
        )

        return now_break

    def wait_async_upload_finish(self):
        """wait for all checkpoint uploads to be completed"""
        self.storage_manager.wait()
        torch.distributed.barrier()

    def query_latest_snapshot_step_boto3(self):
        """query_latest_snapshot_step_boto3
        Returns:
            Tuple(str, int): path of the latest ckpt and its step; if not found, None is returned.
        """
        ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
        if len(ckpt_list) == 0:
            return None, None

        max_normal_step = 0
        ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
        ckpt_list.sort(reverse=True)
        for ckpt in ckpt_list:
            fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
            for fn in fns_list:
                if fn.endswith(".step"):
                    max_normal_step = ckpt
                    break
            if max_normal_step != 0:
                break

        max_normal_step = ckpt_list[0]
        load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))

        snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
        snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
        ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
        max_step_0, max_step_1 = 0, 0
        for ckpt in ckpt_list_1:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
        for ckpt in ckpt_list_2:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))

        snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
        snap_step = max(max_step_0, max_step_1)
        load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
        load_step = max(snap_step, max_normal_step)
        return load_path, load_step

    def query_latest_snapshot_step_local(self):
        max_step, max_step_path = 0, None
        for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
            for fn in files:
                fn = fn.strip("/")
                if fn.endswith(".step"):
                    # We assume that both normal ckpt and snapshot ckpt will store the '.step' file
                    # as an integrity flag.
                    step = int(fn.rsplit(".", maxsplit=1)[0])
                    if max_step < step:
                        max_step = step
                        max_step_path = root

        return max_step_path, max_step

    def query_lastest_ckpt(self):
        latest_checkpoint = None
        # Training was automatically restarted by the process, forcing the latest snapshot to be read.
        if self.save_ckpt_folder:
            if self.save_ckpt_folder.startswith("boto3"):
                latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
            elif self.save_ckpt_folder.startswith("local"):
                latest_checkpoint, step = self.query_latest_snapshot_step_local()
            else:
                latest_checkpoint, step = None, 0

            if latest_checkpoint is not None:
                if gpc.is_rank_for_log():
                    logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
                    )
            else:
                if gpc.is_rank_for_log():
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
                    )

        return latest_checkpoint

    def try_load_model(self, current_time=""):
        model_load_path = None

        if self.load_ckpt_folder and self.load_model_only_folder:
            raise ValueError(
                "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
if you only need to load model weights (for example starting an SFT task for the first time), \
set load_model_only_folder path, if you need to resume training from ckpt, \
set load_ckpt_folder or use default value \
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
            )

        if self.load_ckpt_folder:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
            model_load_path = self.load_ckpt_folder
        elif self.load_model_only_folder:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
                    f"{socket.gethostname()}==========="
                )
            model_load_path = self.load_model_only_folder
        else:
            if gpc.is_rank_for_log():
                logger.info(
                    f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
                    f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                    f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
                )

        # Loading model weights must be done before zero is initialized.
        if model_load_path is not None:
            load_model_checkpoint(folder=model_load_path, model=self.model)

    def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
        """Attempt to restore the training state of the last ckpt.

        Args:
            lr_scheduler (_LRScheduler): lr_scheduler object.
            optimizer (Optimizer): optimizer object.
            lr (float): learning rate.
            train_state (dict): training states.
            train_dl (DataLoader): training dataloader object.
        """
        if self.load_ckpt_folder is not None:
            # load optimizer states.
            if self.load_optimizer:
                load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
            # load lr scheduler states.
            load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
            # load training states.
            load_context(self.load_ckpt_folder, train_dl, train_state)
            # load dataloader sampler states.
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
            if hasattr(train_state, "data_state_dict"):
                train_dl.dataset.load_state_dict(
                    llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
                )
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
        """
        Save checkpoint to the given folder path.
        """

        start = time.time()
        self.set_save_folder(folder, train_state.step_count)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        if gpc.is_rank_for_log():
            logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")

        timer("save-model").start()
        save_model_checkpoint(folder=folder, model=model)
        timer("save-model").stop()

        timer("save-optimizer").start()
        save_optimizer_checkpoint(optim=optimizer, state_path=folder)
        timer("save-optimizer").stop()

        if (
            hasattr(train_state, "data_state_dict")
            and gpc.get_local_rank(ParallelMode.TENSOR) == 0
            and gpc.get_local_rank(ParallelMode.PIPELINE) == 0
        ):
            llm_save(
                os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"),
                saved_obj=train_state.data_state_dict,
            )

        if gpc.is_rank_for_log():
            scheduler_states = scheduler.state_dict()
            llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                sampler_state = train_state.batch_sampler.state_dict()
                llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
            llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())

            if model_config is not None:
                llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)

        torch.distributed.barrier()

        if gpc.is_rank_for_log():
            timer.log(["save-model", "save-optimizer"], logger=logger)
            logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
            if self.storage_manager.async_mode is False:
                llm_save(
                    os.path.join(folder, f"{train_state.step_count}.step"),
                    saved_obj=dict({"step": train_state.step_count}),
                )

    def set_save_folder(self, folder, step):
        self.storage_manager.latest_save_folder = folder
        self.storage_manager.latest_save_step = step

@ -1,15 +1,13 @@
import os
import time
from collections import OrderedDict
from functools import partial
from functools import partial, reduce
from typing import Any, Dict, List, Tuple

import pyecharts
import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.solver.pipeline_utils import partition_uniform
from internlm.core.naive_amp import NaiveAMPModel

mb = 1024 * 1024

@ -107,6 +105,8 @@ class SimpleMemState:
        """
        Update the total memory usage of the model and sub-models.
        """
        self._total_mem = self._layer_mem

        for stat in self.sub_model_stats.values():
            # Update sub-model status first.
            stat.update_total_memory()
@ -169,6 +169,39 @@ class SimpleMemState:
        return {"name": self.layer_name, "children": children}


class ActivationMemState:
    """
    Activation Memory State
    """

    def __init__(self, num_chunks: int) -> None:
        self._num_chunks = num_chunks

        self.inited: List[bool] = [False for _ in range(num_chunks)]
        self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]

    @property
    def total_mem(self) -> int:
        return sum(state.total_mem for state in self.states)

    def dump(self, prefix: str = "") -> str:
        return reduce(lambda x, y: x + y, [state.dump(prefix) for state in self.states])

    def to_json(self, base: int = 1024 * 1024) -> List:
        return [state.to_json(base) for state in self.states]


def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
    num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1

    if num_chunks > 1:
        model = torch.nn.ModuleList([_model.model if isinstance(_model, NaiveAMPModel) else _model for _model in model])
    else:
        model = model.model if isinstance(model, NaiveAMPModel) else model

    return model, num_chunks


class SimpleMemoryProfiler:
    """
    A memory profiler for an llm model.
@ -177,7 +210,7 @@ class SimpleMemoryProfiler:
        model (torch.nn.Module): The model to profile.
        optimizer (torch.optim.Optimizer): The optimizer used for training the model.
        log_file (str): The file to write the memory state information to.
        activation_config (List[str], optional): The list of activation layers to track. Defaults to None.
        total_steps: number of steps to trace.
    """

    def __init__(
@ -186,9 +219,8 @@ class SimpleMemoryProfiler:
        optimizer: torch.optim.Optimizer,
        log_folder: str,
        total_steps: int = 5,
        activation_config: List[str] = None,
    ):
        self._model = model
        self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
        self._optimizer = optimizer
        self._log_folder = log_folder
        self._remaining_steps = total_steps
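# A hedged sketch of how a training script might construct the profiler after this change (the
# log-folder name and the optimizer attribute are assumptions, not taken from this diff):
memory_profiler = SimpleMemoryProfiler(
    model,  # may be an nn.ModuleList of pipeline chunks; _unpack_naive_wrapper handles both
    optimizer.optim,  # assumption: profile the wrapped torch optimizer inside HybridZeroOptimizer
    log_folder=f"memory_trace/rank{gpc.get_global_rank()}",
    total_steps=5,
)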
@ -197,17 +229,20 @@ class SimpleMemoryProfiler:
        self._record_start_time = time.time()

        # For activation memory state.
        self._activation_config = activation_config
        self._activation_mem_inited: bool = False

        self._activation_mem: int = 0
        self._activation_max_count = 0
        self._activation_base_mem: SimpleMemState = SimpleMemState("activations")
        self._activation_mem_max: int = 0
        self._activation_base_mems = ActivationMemState(self._num_model_chunks)

        # Check or create log folder
        os.makedirs(self._log_folder, exist_ok=True)

        # Register activation memory tracking hooks
        self._register_activation_trace_hooks()
        if self._num_model_chunks > 1:
            for chunk_id in range(self._num_model_chunks):
                self._register_activation_trace_hooks(chunk_id, self._model[chunk_id])
        else:
            self._register_activation_trace_hooks(0, self._model)

        # Calculate static parameter cuda memory
        self._param_mem_state = SimpleMemState("param_mem")
@ -221,7 +256,7 @@ class SimpleMemoryProfiler:
        self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))

        # Generate the first memory record
        self.point(create=True)
        self.point(with_options="params,grads,os_params", create=True)

    def point(self, with_options: str = "", create: bool = False) -> None:
        """
@ -272,7 +307,7 @@ class SimpleMemoryProfiler:
        if "os_state" in options:
            layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
        if "activation_base" in options:
            layout_info += "activation_base_layout:\n" + self._activation_base_mem.dump()
            layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()

        # Write memory state information to log file
        file_mode = "w" if create else "a"
@ -315,14 +350,14 @@ class SimpleMemoryProfiler:
            [self._os_params_mem_state.to_json(), self._os_state_mem_state.to_json()],
            "os_memory_sunburst",
        )
        self._render_sunburst_chart(self._activation_base_mem.to_json()["children"], "activation_memory_sunburst")
        self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")
        # Generate summary sunburst chart
        summary_sunburst_data = [
            {"name": "params", "value": self._param_mem_state.total_mem // mb},
            {"name": "grads", "value": self._grad_mem_state.total_mem // mb},
            {"name": "os_params", "value": self._os_params_mem_state.total_mem // mb},
            {"name": "os_state", "value": self._os_state_mem_state.total_mem // mb},
            {"name": "activation", "value": self._activation_base_mem.total_mem // mb},
            {"name": "activation", "value": self._activation_mem_max // mb},
        ]

        self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")
@ -337,12 +372,13 @@ class SimpleMemoryProfiler:
            {},
            {
                "r0": "10%",
                "r": "40%",
                "r": "35%",
                "itemStyle": {"borderWidth": 3},
                "label": {"align": "left"},
            },
            {"r0": "40%", "r": "65%", "label": {"align": "left"}},
            {"r0": "65%", "r": "80%", "label": {"align": "left"}},
            {"r0": "35%", "r": "55%", "label": {"align": "left"}},
            {"r0": "55%", "r": "70%", "label": {"align": "left"}},
            {"r0": "70%", "r": "80%", "label": {"align": "left"}},
            {"r0": "80%", "r": "90%", "label": {"align": "left"}},
            {
                "r0": "90%",
@ -357,7 +393,14 @@ class SimpleMemoryProfiler:
            f"{self._log_folder}/{name}.html"
        )

    def _inner_activation_trace_hook(self, layer_name: str, model: Any, inputs: Any, output: torch.Tensor) -> None:
    def _inner_activation_trace_hook(
        self,
        chunk_id: int,
        layer_name: str,
        model: Any,
        inputs: Any,
        output: torch.Tensor,
    ) -> None:
        """
        Hook function to trace the activation memory usage for an inner layer.
@ -373,13 +416,15 @@ class SimpleMemoryProfiler:
        del model, inputs
        assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"

        if self._stoped or self._activation_mem_inited:
        if self._stoped or self._activation_base_mems.inited[chunk_id]:
            return

        # Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
        self._activation_base_mem.add(layer_name, output.element_size() * output.nelement(), flush=False)
        self._activation_base_mems.states[chunk_id].add(
            layer_name, output.element_size() * output.nelement(), flush=False
        )

    def _activation_trace_hook_forward(self, model: Any, inputs: Any, output: torch.Tensor) -> None:
    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
        """
        Hook function to trace the activation memory usage for a forward pass.
@@ -398,23 +443,24 @@ class SimpleMemoryProfiler:
            return

        # Check if the activation memory has been initialized
        if self._activation_mem_inited is False:
        if self._activation_base_mems.inited[chunk_id] is False:
            self._activation_base_mems.inited[chunk_id] = True
            # Update the total memory of the activation base memory state
            self._activation_base_mem.update_total_memory()
            self._activation_base_mems.states[chunk_id].update_total_memory()
            # Set with_options to "activation_base" to include activation_base_layout in the memory dump
            self._activation_mem_inited = True
            with_options = "activation_base"
        else:
            with_options = ""

        # Accumulate activation memory usage for each forward pass
        self._activation_mem += self._activation_base_mem.total_mem

        # Update activation max count
        if self._activation_mem // self._activation_base_mem.total_mem > self._activation_max_count:
            self._activation_max_count = self._activation_mem // self._activation_base_mem.total_mem
        self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
        if self._activation_mem > self._activation_mem_max:
            self._activation_mem_max = self._activation_mem

        # Trigger a memory record
        self.point()
        self.point(with_options)

    def _activation_tarce_hook_backward(self, model: Any, inputs: Any, grad_outputs: Any) -> None:
    def _activation_tarce_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
        """
        Hook function to trace the activation memory usage for a backward pass.
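The forward hook above now keys its state by chunk_id and tracks a running peak (_activation_mem_max) instead of a pass counter. A minimal, self-contained sketch of that bookkeeping outside InternLM (class and attribute names here are illustrative only):

import torch


class ActivationMemTracker:
    def __init__(self, model: torch.nn.Module):
        self.footprint_bytes = 0   # bytes produced by one forward pass (measured once)
        self.running_bytes = 0     # bytes of activations currently alive
        self.peak_bytes = 0        # peak of running_bytes, analogous to _activation_mem_max
        self._measured = False
        # Leaf modules contribute output activations; the root module marks pass boundaries.
        for _name, mod in model.named_modules():
            if len(mod._modules) == 0:
                mod.register_forward_hook(self._leaf_hook)
        model.register_forward_hook(self._forward_done)
        model.register_full_backward_hook(self._backward_done)

    def _leaf_hook(self, module, inputs, output):
        if not self._measured and isinstance(output, torch.Tensor):
            self.footprint_bytes += output.element_size() * output.nelement()

    def _forward_done(self, module, inputs, output):
        self._measured = True
        self.running_bytes += self.footprint_bytes
        self.peak_bytes = max(self.peak_bytes, self.running_bytes)

    def _backward_done(self, module, grad_input, grad_output):
        self.running_bytes -= self.footprint_bytes


model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
tracker = ActivationMemTracker(model)
out1, out2 = model(torch.randn(32, 64)), model(torch.randn(32, 64))  # two in-flight forwards
(out1.mean() + out2.mean()).backward()
print(tracker.peak_bytes)  # peak activation bytes while both forwards were alive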
@@ -432,37 +478,28 @@ class SimpleMemoryProfiler:
            return

        # Release activation memory usage for each backward pass
        self._activation_mem -= self._activation_base_mem.total_mem
        self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem

        # Trigger a memory record
        self.point()

    def _register_activation_trace_hooks(self) -> None:
    def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
        """
        Register activation trace hooks for the model and each submodule in the model.
        """

        # Register inner activation trace hooks for each submodule in the model
        for layer_name in self._activation_config:
            # Register a hook for every activation
            model = self._model
            sub_models = layer_name.split(".")
            # Get the target sub-model
            for sub_model_name in sub_models:
                try:
                    model = model.get_submodule(sub_model_name)
                except AttributeError:
                    model = None
                    break

        for layer_name, sub_model in model_chunk.named_modules():
            # Register the hook
            if model is not None:
                model.register_forward_hook(partial(self._inner_activation_trace_hook, layer_name))
            if len(sub_model._modules) != 0:
                continue  # TODO: in some special cases, we may need some additional configuration to correct

            sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))

        # Register a forward hook for the main model to track activation memory usage
        self._model.register_forward_hook(self._activation_trace_hook_forward)
        model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
        # Register a backward hook for the main model to release activation memory usage
        self._model.register_full_backward_hook(self._activation_tarce_hook_backward)
        model_chunk.register_full_backward_hook(partial(self._activation_tarce_hook_backward, chunk_id))

    def _calc_tensor_memory(
        self, root_stat: SimpleMemState, named_tensors: Dict[str, torch.Tensor], require_grad: bool = False
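Hook registration above walks named_modules() and uses functools.partial to bind chunk_id and layer_name into one shared hook. A small standalone sketch of that binding pattern (not InternLM code):

from functools import partial

import torch


def leaf_hook(chunk_id: int, layer_name: str, module, inputs, output):
    # chunk_id/layer_name arrive via partial; module/inputs/output come from PyTorch.
    size = output.element_size() * output.nelement() if isinstance(output, torch.Tensor) else 0
    print(f"chunk {chunk_id} | {layer_name}: {size} bytes")


chunks = torch.nn.ModuleList([torch.nn.Sequential(torch.nn.Linear(16, 16)) for _ in range(2)])
for chunk_id, chunk in enumerate(chunks):
    for layer_name, sub_model in chunk.named_modules():
        if len(sub_model._modules) != 0:  # only leaf modules produce activations we count
            continue
        sub_model.register_forward_hook(partial(leaf_hook, chunk_id, layer_name))

x = torch.randn(4, 16)
for chunk_id, chunk in enumerate(chunks):
    x = chunk(x)  # each chunk's leaf hooks fire with their bound chunk_id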
@@ -554,48 +591,6 @@ class SimpleMemoryProfiler:
        self._calc_tensor_memory(root_stat, named_tensors)


def build_activation_config(num_layers: int, num_chunks: int = 1) -> List[str]:
    # TODO: support interleaved pipeline scheduling.
    assert num_chunks == 1, "Only support num_chunks == 1"

    if gpc.is_initialized(ParallelMode.PIPELINE):
        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    else:
        pipeline_size = 1
        pipeline_rank = 0

    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
    start, end = parts[0]
    num_blocks = end - start

    block_conf_tmpl = [
        "mixer.rotary_emb",
        "mixer.Wqkv",
        "mixer.inner_attn",
        "mixer.inner_cross_attn",
        "mixer.out_proj",
        # "dropout1",  # skip when dropout_selective_checkpoint is True
        # "dropout2",  # skip when dropout_selective_checkpoint is True
        "norm1",
        "norm2",
        "mlp.w1",
        "mlp.w2",
        "mlp.w3",
    ]

    block_conf = []
    for block_id in range(num_blocks):
        block_conf += [f"blocks.{block_id}.{layer}" for layer in block_conf_tmpl]

    # We don't need to care about whether the embedding, norm, and head layers exist in the model after partitioning.
    # If they don't exist, they will be automatically ignored when registering activation trace hooks.
    activation_conf = ["embedding", "norm", "head"] + block_conf

    return activation_conf


if __name__ == "__main__":

    class SimpleModel(torch.nn.Module):
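The removed build_activation_config used partition_uniform to find which transformer blocks live on the local pipeline rank. A rough, hypothetical sketch of a uniform layer partition for num_chunks == 1 (this is an illustration, not the real partition_uniform):

from typing import List, Tuple


def uniform_layer_ranges(num_layers: int, pipeline_size: int) -> List[Tuple[int, int]]:
    """Split num_layers into pipeline_size contiguous [start, end) ranges; earlier ranks
    take one extra layer when the division is uneven."""
    base, extra = divmod(num_layers, pipeline_size)
    ranges, start = [], 0
    for rank in range(pipeline_size):
        end = start + base + (1 if rank < extra else 0)
        ranges.append((start, end))
        start = end
    return ranges


# e.g. 16 layers over 4 pipeline stages -> [(0, 4), (4, 8), (8, 12), (12, 16)]
print(uniform_layer_ranges(16, 4))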
@@ -635,32 +630,39 @@ if __name__ == "__main__":

            return output

    def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
        if _num_chunks > 1:
            _output = _input
            for _model_chunk in _model_chunks:
                _output = _model_chunk(_output)
        else:
            _output = _model_chunks(_input)

        return _output

    # num_chunks config
    _num_chunks = 1

    # init model and optimizer
    _model: torch.nn.Module = SimpleModel()
    if _num_chunks > 1:
        _chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
        _model = torch.nn.ModuleList(_chunks).cuda()
    else:
        _model: torch.nn.Module = SimpleModel().cuda()
    _optimizer = torch.optim.Adam(_model.parameters())

    # create activation config for simple model layer by layer.
    activation_configs = [
        # model level 0
        "layer1",
        "layer2",
        "layer3",
        # model level 1
        "layer2.layer1",
        "layer2.layer3",
    ]

    _model.modules()

    # init profiler
    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler.log", activation_configs)
    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler", total_steps=1)

    _optimizer.zero_grad()

    x1 = torch.randn((128, 5120))
    x2 = torch.randn((128, 5120))
    out1 = _model(x1)
    out2 = _model(x2)
    # inputs
    x1 = torch.randn((128, 5120)).cuda()
    x2 = torch.randn((128, 5120)).cuda()
    # forward
    out1 = _simple_schedule(_num_chunks, _model, x1)
    out2 = _simple_schedule(_num_chunks, _model, x2)
    # backward
    out1.mean().backward()
    out2.mean().backward()
@@ -15,8 +15,6 @@ from asyncio.tasks import ALL_COMPLETED
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Union

import boto3
import botocore
import torch
import torch.distributed as dist


@@ -24,6 +22,13 @@ from internlm.core.context import global_context as gpc
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger

try:
    import boto3
    import botocore
except ImportError:
    pass


logger = get_logger(__file__)

boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
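With boto3/botocore now optional imports, a common companion pattern is to fail only when the boto3 backend is actually requested. A small sketch under that assumption (require_boto3 is illustrative, not part of the repository):

try:
    import boto3  # noqa: F401
except ImportError:
    boto3 = None


def require_boto3():
    # Raise a clear error only when an s3/boto3 checkpoint path is actually used.
    if boto3 is None:
        raise RuntimeError("boto3 is required for boto3: checkpoint folders; install it or use a local: path.")
    return boto3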
@@ -234,13 +239,13 @@ class Boto3Client(StorageClient):
        """
        paginator = handler.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)

        folder_name_list = []
        for page in pages:
            for obj in page["Contents"]:
                fp: str = obj["Key"]
                folder_name_list.append(fp.rsplit("/", maxsplit=1)[1])
        return folder_name_list
            if "Contents" in page:
                for obj in page["Contents"]:
                    pth: str = obj["Key"]
                    folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
        return list(set(folder_name_list))

    @staticmethod
    def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
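The rewritten listing skips pages that carry no "Contents" key and deduplicates the first path component after the prefix. A standalone sketch of the same idea ("my-bucket" and "llm_ckpts/" are placeholder values):

import boto3


def list_child_folders(bucket: str, prefix: str) -> list:
    client = boto3.client("s3")
    paginator = client.get_paginator("list_objects_v2")
    children = set()
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        if "Contents" not in page:  # a page can be empty when the prefix matches nothing
            continue
        for obj in page["Contents"]:
            relative = obj["Key"][len(prefix):].strip("/")
            if relative:
                children.add(relative.split("/", maxsplit=1)[0])
    return sorted(children)


# print(list_child_folders("my-bucket", "llm_ckpts/"))

boto3 also supports Delimiter="/" with the returned CommonPrefixes for this kind of "folder" listing, which avoids enumerating every object under the prefix.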
@@ -391,6 +396,11 @@ class StorageManager(metaclass=SingletonMeta):
        self.tmp_local_folder = tmp_local_folder
        self.async_mode = async_mode
        self.has_warning = False
        self._async_loop = None
        self._thread_pool = None
        self.latest_save_folder = None
        self.latest_save_step = 0
        self.async_task_peeding = False

        if enable_save and self.async_mode:
            self._async_loop = asyncio.new_event_loop()

@@ -485,6 +495,7 @@ class StorageManager(metaclass=SingletonMeta):
                torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
            self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
            os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
            self.async_task_peeding = True
        else:
            meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
            self.upload_count += 1
@@ -523,23 +534,22 @@ class StorageManager(metaclass=SingletonMeta):
            pass

    async def _sync_tasks(self) -> Awaitable[None]:
        if not self._async_stack:
            return

        await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)

        for task in self._async_stack:
            try:
                task.exception()
            except InvalidStateError:
                continue
            except Exception as e:
                file_id = len(self._exception_list)
                self._exception_list.append((e, file_id))
                logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}")

        self._async_stack.clear()
        if self._async_stack:
            await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED)
            count = 0
            while self._async_stack:
                t = self._async_stack[0]
                try:
                    e = t.exception()
                    if e:
                        self._exception_list.append((e, count))
                        logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}")
                        # raise e
                    count += 1
                    self._async_stack.pop(0)
                except InvalidStateError:
                    # Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
                    pass

    def async_executor(self, fn: Callable, *args, **kwargs) -> None:
        """
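_sync_tasks now waits for all pending uploads and then pops the tasks one by one, recording exceptions instead of raising. A minimal runnable sketch of that draining pattern (names here are illustrative):

import asyncio


async def upload(i: int) -> None:
    await asyncio.sleep(0.01)
    if i == 1:
        raise IOError(f"upload {i} failed")


async def drain(tasks: list) -> list:
    errors = []
    if tasks:
        await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
        while tasks:
            task = tasks.pop(0)
            exc = task.exception()  # safe here: every task is done after the wait above
            if exc:
                errors.append(exc)
    return errors


async def main() -> None:
    tasks = [asyncio.create_task(upload(i)) for i in range(3)]
    print(await drain(tasks))  # -> [OSError('upload 1 failed')]


asyncio.run(main())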
@@ -559,11 +569,14 @@ class StorageManager(metaclass=SingletonMeta):
        if not self.async_mode:
            return

        if not self.async_task_peeding:
            return

        if self._async_loop:
            self._async_loop.run_until_complete(self._sync_tasks())

        if self._exception_list:
            for file_id, error_msg in self._exception_list:
            for error_msg, file_id in self._exception_list:
                logger.error(
                    f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} "
                    f"failed on step {self.upload_count}: {error_msg}"

@@ -577,10 +590,16 @@ class StorageManager(metaclass=SingletonMeta):
        self._del_tmp_folder()
        self._exception_list.clear()
        self._to_be_del_files.clear()
        self.async_task_peeding = False

        if gpc.is_rank_for_log():
            logger.info("all async uploads succeeded!")
            self.upload_count += 1
            if self.async_mode:
                self.save(
                    os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
                    saved_obj=dict({"step": self.latest_save_step}),
                    async_upload=False,
                )


storage_manager: StorageManager = None
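StorageManager is declared with metaclass=SingletonMeta, so every caller shares one manager and one async upload queue. A minimal sketch of such a metaclass (illustrative; not necessarily identical to internlm.utils.common.SingletonMeta):

class SingletonMetaSketch(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct the class once; later calls return the cached instance.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Manager(metaclass=SingletonMetaSketch):
    def __init__(self, name: str = "default"):
        self.name = name


assert Manager("a") is Manager("b")  # the second call reuses the first instance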
@@ -11,10 +11,6 @@ from torch.utils.tensorboard import SummaryWriter
from internlm.core.context import global_context as gpc


def copy_ignore_folder(source_path, target_path):
    os.system(f"cp -r {source_path}/* {target_path}/")


def tb_save_run_info(writer, config_lines, global_step=0):
    writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step)
    lines = []

@@ -44,7 +40,8 @@ def init_tb_writer(
    if gpc.get_global_rank() == 0:
        if resume_tb_folder is not None:
            logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
            copy_ignore_folder(resume_tb_folder, tb_folder)
            os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
            os.system(f"chmod -R +w {tb_folder}/")
        else:
            logger.info(f"Login tensorboard logs to: {tb_folder}")
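The resume branch above shells out to cp/chmod. A pure-Python alternative sketch (not what the repository does) that copies the old TensorBoard run into the new folder and makes the files writable:

import os
import shutil
import stat


def copy_resumed_tb_logs(resume_tb_folder: str, tb_folder: str) -> None:
    shutil.copytree(resume_tb_folder, tb_folder, dirs_exist_ok=True)  # Python 3.8+
    for root, _dirs, files in os.walk(tb_folder):
        for name in files:
            path = os.path.join(root, name)
            os.chmod(path, os.stat(path).st_mode | stat.S_IWUSR | stat.S_IWGRP)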
train.py
@ -5,99 +5,48 @@ import socket
|
|||
import time
|
||||
import traceback
|
||||
from functools import partial
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import internlm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.scheduler import SchedulerMetricHook
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
|
||||
from internlm.data.collaters import jsonl_ds_collate_fn, packed_collate_fn
|
||||
from internlm.data.dataset import get_dataset_dict
|
||||
from internlm.data.dummy_dataset import RandomDataset
|
||||
from internlm.data.packed_dataset import (
|
||||
PackedDataset,
|
||||
PackedDatasetWithoutCuSeqlen,
|
||||
get_packed_dataset_without_short_length,
|
||||
)
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.model.moe import create_moe_param_groups, has_moe_layers
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message, set_env_var
|
||||
from internlm.monitor import initialize_monitor_manager, send_alert_message
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.train import (
|
||||
get_train_data_loader,
|
||||
get_validation_data_loader,
|
||||
initialize_distributed_env,
|
||||
initialize_llm_profile,
|
||||
initialize_model,
|
||||
initialize_optimizer,
|
||||
load_new_batch,
|
||||
record_current_batch_training_metrics,
|
||||
)
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
get_master_node,
|
||||
get_megatron_flops,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import (
|
||||
CheckpointSaveManager,
|
||||
load_context,
|
||||
load_model_checkpoint,
|
||||
load_optimizer_checkpoint,
|
||||
load_sampler,
|
||||
load_scheduler,
|
||||
)
|
||||
from internlm.utils.parallel import (
|
||||
get_parallel_log_file_name,
|
||||
is_no_pp_or_last_stage,
|
||||
sync_model_param_with_ep,
|
||||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
from internlm.utils.simple_memory_profiler import (
|
||||
SimpleMemoryProfiler,
|
||||
build_activation_config,
|
||||
)
|
||||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
from internlm.utils.parallel import get_parallel_log_file_name
|
||||
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
|
||||
from internlm.utils.writer import Writer
|
||||
|
||||
# global llm logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def initialize_distributed_env(config: str, launcher: str = "slurm", master_port: int = 8888, seed: int = 1024):
|
||||
"""
|
||||
Initialize distributed environment for distributed training.
|
||||
|
||||
Args:
|
||||
config (str): Config file path.
|
||||
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
|
||||
master_port (str): The master port for distributed training. 8888 by default.
|
||||
seed (int, optional): Specified random seed for every process. 1024 by default.
|
||||
"""
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if launcher == "torch":
|
||||
internlm.launch_from_torch(config=config, seed=seed)
|
||||
elif launcher == "slurm":
|
||||
internlm.launch_from_slurm(
|
||||
config=config,
|
||||
host=get_master_node(),
|
||||
port=master_port,
|
||||
seed=seed,
|
||||
)
|
||||
else:
|
||||
assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
|
||||
|
||||
|
||||
def initialize_llm_logger(start_time: str):
|
||||
"""
|
||||
Initialize customed uniscale logger.
|
||||
|
|
@ -118,338 +67,14 @@ def initialize_llm_logger(start_time: str):
|
|||
return uniscale_logger
|
||||
|
||||
|
||||
def initialize_model():
|
||||
"""
|
||||
Initialize model.
|
||||
|
||||
Returns: The neural network model to be trained or evaluated.
|
||||
"""
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model))
|
||||
if isinstance(model, nn.ModuleList):
|
||||
model = nn.ModuleList(
|
||||
[
|
||||
NaiveAMPModel(
|
||||
model=_m,
|
||||
output_to_fp32=False, # manually controlled by interleaved pipleline scheduler
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
for _m in model
|
||||
]
|
||||
)
|
||||
else:
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=is_no_pp_or_last_stage(),
|
||||
dtype=gpc.config.model.get("dtype", torch.half),
|
||||
sync_buffer=False,
|
||||
)
|
||||
|
||||
# This sync is very important, cause the model weights kept in optimizer are copied
|
||||
# from the origin parameters in the memory, so we should make sure the dp sync
|
||||
# does not influence the model weights in optimizer be different with the origin parameters.
|
||||
sync_model_param_with_ep(model)
|
||||
|
||||
# This function is needed to make sure parameters that are not splitted by tensor parallelism are
|
||||
# the same across tensor parallelism.
|
||||
sync_model_param_within_tp(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_train_data_loader(num_worker: int = 0):
|
||||
"""
|
||||
Generate and return the training data loader.
|
||||
|
||||
Returns: A tuple of (train_dl, dataset_types).
|
||||
"""
|
||||
|
||||
# Get the dataset types
|
||||
dataset_types = None
|
||||
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
# Get the sample weight dictionary
|
||||
train_folder = data_cfg.train_folder
|
||||
|
||||
if not train_folder:
|
||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||
if data_cfg.pack_sample_into_one:
|
||||
train_ds = PackedDatasetWithoutCuSeqlen(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = PackedDataset(
|
||||
train_ds, max_length_per_sample=data_cfg.seq_len, packed_length=data_cfg.packed_length
|
||||
)
|
||||
else:
|
||||
train_ds = get_packed_dataset_without_short_length(
|
||||
folder=data_cfg.train_folder,
|
||||
packed_length=data_cfg.packed_length,
|
||||
max_length_per_sample=data_cfg.seq_len,
|
||||
show_progress=dist.get_rank() == 0,
|
||||
min_length=data_cfg.min_length,
|
||||
min_length_dict=data_cfg.get("min_length_dict", {}),
|
||||
pack_into_one_sample=data_cfg.pack_sample_into_one,
|
||||
)
|
||||
|
||||
# partition already completed
|
||||
# assert isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen))
|
||||
if isinstance(train_ds, (PackedDataset, PackedDatasetWithoutCuSeqlen)):
|
||||
datasets = [train_ds]
|
||||
else:
|
||||
datasets = train_ds.datasets
|
||||
|
||||
# Create the training dataset sampler
|
||||
train_sampler = StaticBatchSampler(
|
||||
datasets,
|
||||
batch_size=data_cfg.micro_num,
|
||||
rampup_batch_size=data_cfg.rampup_batch_size,
|
||||
micro_bsz=data_cfg.micro_bsz,
|
||||
seed=1024,
|
||||
drop_last=True,
|
||||
data_rank=gpc.get_local_rank(ParallelMode.DATA),
|
||||
data_world_size=gpc.get_world_size(ParallelMode.DATA),
|
||||
)
|
||||
|
||||
train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
|
||||
|
||||
# Create the training data loader
|
||||
train_dl = DataLoader(
|
||||
dataset=train_ds,
|
||||
batch_sampler=train_sampler,
|
||||
num_workers=num_worker,
|
||||
pin_memory=True,
|
||||
collate_fn=train_collate_fn,
|
||||
persistent_workers=True,
|
||||
)
|
||||
|
||||
return train_dl, dataset_types
|
||||
|
||||
|
||||
def get_validation_data_loader(num_worker: int = 0):
|
||||
"""Generate and return the validation data loader."""
|
||||
|
||||
data_cfg = gpc.config.data
|
||||
|
||||
if not data_cfg.valid_folder:
|
||||
val_ds = RandomDataset(num_samples=gpc.get_world_size(ParallelMode.DATA) * 500, max_len=data_cfg.seq_len)
|
||||
else:
|
||||
val_ds = get_dataset_dict(folder=data_cfg.valid_folder, split="")
|
||||
|
||||
if not isinstance(val_ds, dict):
|
||||
val_ds = {"val": val_ds}
|
||||
|
||||
val_collate_fn = partial(jsonl_ds_collate_fn, max_length_per_sample=data_cfg.seq_len)
|
||||
|
||||
val_dls = {}
|
||||
for val_name, ds in val_ds.items():
|
||||
# making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
|
||||
# otherwise too much data may be dropped
|
||||
batch_size = min(
|
||||
data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
|
||||
)
|
||||
batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
|
||||
|
||||
if batch_size == 0 and gpc.is_rank_for_log():
|
||||
logger.info(f"skip validate {val_name}.") # pylint: disable=W1203
|
||||
continue
|
||||
|
||||
val_dls[val_name] = get_dpsampler_dataloader(
|
||||
ds, shuffle=False, num_workers=num_worker, batch_size=batch_size, collate_fn=val_collate_fn, drop_last=True
|
||||
) # drop_last=True, otherwise it may cause problems in the last batch
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info( # pylint: disable=W1203
|
||||
f"load validation dataset {val_name} with valid batch size {str(batch_size)} and "
|
||||
f"samples {str(len(val_dls[val_name]))}."
|
||||
)
|
||||
|
||||
return val_dls
|
||||
|
||||
|
||||
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
|
||||
"""
|
||||
Load and return the new batch data based on training data loader.
|
||||
|
||||
Args:
|
||||
train_dl (torch.utils.data.DataLoader): Dataloader for training.
|
||||
train_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
|
||||
train_state (TrainState): Current training state.
|
||||
|
||||
Returns: A batch data and the updated train_iter.
|
||||
"""
|
||||
|
||||
timer("batch-gen").start()
|
||||
try:
|
||||
batch = next(train_iter) # structure is ({'input_ids': Tensor, 'cu_seqlens': Tensor}, Tensor)
|
||||
next(train_state.batch_sampler_iter)
|
||||
except StopIteration:
|
||||
train_iter = iter(train_dl)
|
||||
batch = next(train_iter)
|
||||
train_state.batch_sampler_iter = iter(train_state.batch_sampler)
|
||||
next(train_state.batch_sampler_iter)
|
||||
train_state.num_consumed_samples_in_epoch = 0
|
||||
timer("batch-gen").stop()
|
||||
|
||||
return batch, train_iter
|
||||
|
||||
|
||||
def initialize_optimizer(model: nn.Module):
    """
    Initialize optimizer.

    Args:
        model (torch.nn.Module): Your model instance to be trained or evaluated.

    Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
    """

    adam_cfg = gpc.config.adam
    if gpc.config.model.num_experts > 1:
        params = create_moe_param_groups(model, adam_cfg.weight_decay)
    else:
        params = [{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}]
    naive_optimizer = torch.optim.AdamW(
        params=params,
        lr=adam_cfg.lr,
        betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
        eps=adam_cfg.adam_eps,
    )

    has_moe = has_moe_layers(model)
    optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=gpc.config.grad_scaler,
        zero_cfg=gpc.config.hybrid_zero_optimizer,
        has_moe=has_moe,
    )

    beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

    lr_scheduler = FineTuneCosineAnnealingWarmupLR(optimizer, **gpc.config.lr_scheduler)

    return optimizer, beta2_scheduler, lr_scheduler

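create_moe_param_groups above separates expert weights from dense weights so the optimizer (and the ZeRO wrapper) can treat them differently. A rough illustrative sketch of that idea, assuming a hypothetical is_expert flag on expert parameters (this is not the InternLM implementation):

import torch


def split_moe_param_groups(model: torch.nn.Module, weight_decay: float):
    dense_params, expert_params = [], []
    for param in model.parameters():
        # `is_expert` is a hypothetical marker a MoE layer might set on its expert weights.
        if getattr(param, "is_expert", False):
            expert_params.append(param)
        else:
            dense_params.append(param)
    groups = [{"params": dense_params, "weight_decay": weight_decay}]
    if expert_params:
        # Keeping experts in their own group lets downstream logic (e.g. gradient reduction
        # over the expert-parallel group only) identify them cheaply; extra keys like "moe"
        # are preserved by torch optimizers.
        groups.append({"params": expert_params, "weight_decay": weight_decay, "moe": True})
    return groups


# optimizer = torch.optim.AdamW(split_moe_param_groups(model, weight_decay=0.01), lr=1e-4)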
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
writer,
|
||||
success_update,
|
||||
batch_count,
|
||||
batch,
|
||||
train_state,
|
||||
optimizer,
|
||||
beta2_scheduler,
|
||||
trainer,
|
||||
start_time,
|
||||
loss,
|
||||
grad_norm,
|
||||
metric,
|
||||
update_panel,
|
||||
):
|
||||
"""
|
||||
Print some training metrics of current batch.
|
||||
"""
|
||||
|
||||
set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))
|
||||
|
||||
if success_update in (0, True):
|
||||
train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
|
||||
if is_no_pp_or_last_stage():
|
||||
acc_perplex = metric.get_metric()
|
||||
|
||||
if success_update and gpc.is_rank_for_log():
|
||||
lr = optimizer.param_groups[0]["lr"]
|
||||
if hasattr(trainer.engine.optimizer, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.grad_scaler._scale.item()
|
||||
elif hasattr(trainer.engine.optimizer.optim, "grad_scaler"):
|
||||
scaler = trainer.engine.optimizer.optim.grad_scaler._scale.item()
|
||||
|
||||
num_tokens_in_batch = batch[1].nelement()
|
||||
num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
|
||||
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
|
||||
|
||||
tk_per_gpu = 0
|
||||
tk_per_gpu = round(
|
||||
num_tokens_in_batch
|
||||
* gpc.get_world_size(ParallelMode.DATA)
|
||||
/ gpc.get_world_size(ParallelMode.GLOBAL)
|
||||
/ (time.time() - start_time),
|
||||
2,
|
||||
)
|
||||
|
||||
tflops = get_tflops_func((time.time() - start_time))
|
||||
|
||||
infos = {
|
||||
"tflops": tflops,
|
||||
"step": batch_count,
|
||||
"loss": loss.item(),
|
||||
"tgs (tokens/gpu/second)": tk_per_gpu,
|
||||
"lr": lr,
|
||||
"loss_scale": scaler,
|
||||
"grad_norm": grad_norm,
|
||||
}
|
||||
|
||||
infos["micro_num"] = len(batch[1])
|
||||
infos["num_consumed_tokens"] = train_state.num_consumed_tokens
|
||||
infos["inf_nan_skip_batches"] = train_state.inf_nan_skip_batches
|
||||
infos["num_samples_in_batch"] = num_samples_in_batch # the number of batches which have the most samples
|
||||
infos["largest_length"] = max_length_in_batch # the longest input
|
||||
infos["largest_batch"] = max_samples_in_batch # the batch with the most samples
|
||||
infos["smallest_batch"] = min_samples_in_batch
|
||||
infos["adam_beta2"] = beta2_scheduler.get_beta2()
|
||||
|
||||
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
||||
infos["fwd_bwd_time"] = fwd_bwd_time
|
||||
|
||||
for key, value in acc_perplex.items():
|
||||
infos[key] = value
|
||||
|
||||
line = ""
|
||||
for key, value in infos.items():
|
||||
line += f"{key}={value} "
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
if update_panel:
|
||||
logger.info(
|
||||
line,
|
||||
extra={
|
||||
"step": batch_count,
|
||||
"lr": lr,
|
||||
"num_consumed_tokens": train_state.num_consumed_tokens,
|
||||
"grad_norm": grad_norm,
|
||||
"loss": loss.item(),
|
||||
"flops": tflops,
|
||||
"tgs": tk_per_gpu,
|
||||
"acc": acc_perplex["acc"],
|
||||
"perplexity": acc_perplex["perplexity"],
|
||||
"fwd_bwd_time": fwd_bwd_time,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info(line)
|
||||
|
||||
# if loss spike occurs, send alert info to feishu
|
||||
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
|
||||
|
||||
|
||||
def main(args):
|
||||
# init setting
|
||||
skip_batches = gpc.config.data.skip_batches
|
||||
total_steps = gpc.config.data.total_steps
|
||||
valid_every = gpc.config.data.valid_every
|
||||
load_optimizer = gpc.config.ckpt.load_optimizer
|
||||
label_smoothing = gpc.config.loss.label_smoothing
|
||||
lr = gpc.config.adam.lr
|
||||
|
||||
load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None)
|
||||
load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None)
|
||||
|
||||
get_tflops_func = partial(
|
||||
get_megatron_flops,
|
||||
checkpoint=gpc.config.model.checkpoint,
|
||||
|
|
@@ -485,32 +110,19 @@ def main(args):
        enable_tb=gpc.config.enable_tb,
    )

    model_load_path = None
    if load_resume_ckpt_folder is not None:
        logger.info(  # pylint: disable=W1203
            f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:"
            f"{socket.gethostname()}==========="
        )
        model_load_path = load_resume_ckpt_folder
    elif load_model_only_folder is not None:
        logger.info(  # pylint: disable=W1203
            f"===========SFT training from `{load_model_only_folder}` {current_time} on host:"
            f"{socket.gethostname()}==========="
        )
        model_load_path = load_model_only_folder
    else:
        logger.info(  # pylint: disable=W1203
            f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
            f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
            f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
        )

    # initialize and resume train state
    train_state = TrainState(gpc.config)

    # initialize model
    model = initialize_model()

    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        model_config=gpc.config.model,
        feishu_address=gpc.config.alert_address,
    )

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@@ -520,30 +132,12 @@ def main(args):
    train_state.init_batch_sampler(train_dl)

    # Loading model weights must be done before zero is initialized.
    if model_load_path is not None:
        load_model_checkpoint(folder=model_load_path, model=model)
    ckpt_manager.try_load_model(current_time)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    # Loading other persistent training states.
    if load_resume_ckpt_folder is not None:
        # load lr scheduler states.
        load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
        # load training states.
        load_context(load_resume_ckpt_folder, train_dl, train_state)
        # load dataloader sampler states.
        load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler)
        # load optimizer states.
        if load_optimizer:
            load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer)

    ckpt_save_manager = CheckpointSaveManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model_config=gpc.config.model,
    )
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)

    # initialize metric for calculating accuracy and perplexity
    metric = AccPerplex(
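The resume calls removed above are consolidated behind CheckpointManager.try_load_model / try_resume_training. A rough sketch of what such a helper might wrap, built only from the calls shown in this hunk (an illustration, not the actual CheckpointManager code):

from internlm.utils.model_checkpoint import (
    load_context,
    load_optimizer_checkpoint,
    load_sampler,
    load_scheduler,
)


def resume_training_states(load_folder, lr_scheduler, optimizer, lr, train_state, train_dl, load_optimizer=True):
    if load_folder is None:
        return
    # lr scheduler, trainer context, and sampler states live alongside the model weights.
    load_scheduler(load_folder, lr_scheduler, optimizer, lr, train_state)
    load_context(load_folder, train_dl, train_state)
    load_sampler(load_folder, train_dl.batch_sampler)
    if load_optimizer:
        load_optimizer_checkpoint(load_folder, optimizer)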
@@ -579,12 +173,11 @@ def main(args):
    # initialize simple memory profiler
    if args.profiling:
        memory_profiler = SimpleMemoryProfiler(
            model.model,
            model,
            optimizer.optim,
            log_folder=f"memory_trace/rank{gpc.get_global_rank()}_"
            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}",
            activation_config=build_activation_config(gpc.config.model.num_layers),
        )
    else:
        memory_profiler = None
@ -597,86 +190,85 @@ def main(args):
|
|||
# transfer the train data loader into train data iterator
|
||||
train_iter = iter(train_dl)
|
||||
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
start_time = time.time()
|
||||
timer("one-batch").start()
|
||||
|
||||
# load batch data
|
||||
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
# load batch data
|
||||
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...")
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
# process data
|
||||
if batch[0].get("type_ids", None) is not None:
|
||||
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
|
||||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
|
||||
_, _, loss, moe_loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
|
||||
success_update, grad_norm_groups = trainer_result
|
||||
if success_update: # update parameters successfully
|
||||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address,
|
||||
message=f"Warning: skip parameter update at step {batch_count}.",
|
||||
)
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
writer=writer,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
train_state=train_state,
|
||||
optimizer=optimizer,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
moe_loss=moe_loss,
|
||||
grad_norm=np.array(grad_norm_groups),
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
# record the consumed samples in training
|
||||
train_state.batch_count = batch_count
|
||||
train_state.num_consumed_samples_in_epoch += len(batch[1])
|
||||
if batch_skipper(batch_count): # skip this batch
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Skip batch count:`{batch_count}`...") # pylint: disable=W1203
|
||||
timer("one-batch").stop()
|
||||
continue
|
||||
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
type_ids = batch[0].pop("type_ids", None)
|
||||
# process data
|
||||
# if use_flash_attn is False, we need to unpack type_ids
|
||||
if not gpc.config.model.use_flash_attn:
|
||||
type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
|
||||
if type_ids is not None:
|
||||
metric.set_current_type_ids(type_ids=type_ids)
|
||||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=False,
|
||||
moe_loss_coeff=gpc.config.loss.moe_loss_coeff,
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
trainer_result = trainer.step()
|
||||
assert trainer_result is not None
|
||||
|
||||
success_update, grad_norm_groups = trainer_result
|
||||
if success_update: # update parameters successfully
|
||||
train_state.step_count += 1
|
||||
else:
|
||||
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
|
||||
if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
|
||||
)
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
writer=writer,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
train_state=train_state,
|
||||
optimizer=optimizer,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
trainer=trainer,
|
||||
start_time=start_time,
|
||||
loss=loss,
|
||||
grad_norm=np.array(grad_norm_groups),
|
||||
metric=metric,
|
||||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
timer("one-batch").stop()
|
||||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
with switch_sequence_parallel_mode():
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
|
|
@@ -686,14 +278,19 @@ def main(args):
                update_panel=uniscale_logger is not None,
            )

            if memory_profiler is not None:
                memory_profiler.step()
            # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
            # # save batch sampler that tracks the true consumed samples
            now_break = ckpt_manager.try_save_checkpoint(train_state)
            if now_break:
                break

        # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
        # # save batch sampler that tracks the true consumed samples
        ckpt_save_manager.try_save_checkpoint(train_state)
        if memory_profiler is not None:
            memory_profiler.step()

    ckpt_save_manager.wait_async_upload_finish()
            if batch_count % 2 == 0:
                prof.step()

    ckpt_manager.wait_async_upload_finish()


if __name__ == "__main__":