fix(moe): fix moe compatibility for fsdp and memory profiling (#417)

* fix moe compatibility for fsdp and memory profiling

* update moe config
Wenwen Qu 2023-10-17 14:13:48 +08:00 committed by GitHub
parent 37e0c86e5a
commit eeef07934a
3 changed files with 19 additions and 10 deletions


@@ -4,7 +4,7 @@ DO_ALERT = False
 SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
-MLP_RATIO = 8 / 3
+MLP_RATIO = 4 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
@@ -30,6 +30,14 @@ ckpt = dict(
     # 2. the 'content means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
     # 3. the ckpt_type means the type of checkpoint to be loaded, now only 'normal' type is supported.
     load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
+    # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
+    # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
+    # with an automatic restart mechanism upon training reboot.
+    # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
+    # path specified in `load_ckpt_info` by default.
+    # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
+    # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
+    auto_resume=True,
     checkpoint_every=CHECKPOINT_EVERY,
     async_upload=True,  # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporarily files during asynchronous upload.
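A note on the precedence these new comments establish: with auto_resume left at its default of True, the trainer prefers the newest snapshot under save_ckpt_folder and ignores load_ckpt_info entirely. A minimal sketch of that policy, with assumed names and folder layout (not InternLM's actual checkpoint code):

import os
from typing import Optional

def resolve_resume_path(save_ckpt_folder: str, load_ckpt_path: Optional[str], auto_resume: bool) -> Optional[str]:
    # Hypothetical helper: auto_resume takes precedence over load_ckpt_info
    # once any snapshot exists in save_ckpt_folder.
    if auto_resume and os.path.isdir(save_ckpt_folder):
        snapshots = sorted(os.listdir(save_ckpt_folder))  # assumes snapshot names sort by step
        if snapshots:
            return os.path.join(save_ckpt_folder, snapshots[-1])
    # auto_resume=False, or nothing saved yet: fall back to load_ckpt_info,
    # which is None when training from scratch.
    return load_ckpt_path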
@@ -43,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=1,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -81,8 +89,8 @@ grad_scaler = dict(
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimzer overlap_communication
-    overlap_sync_grad=True,
-    overlap_sync_param=True,
+    overlap_sync_grad=False,
+    overlap_sync_param=False,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping
@@ -133,7 +141,7 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
-    num_experts=4,
+    num_experts=8,
     moe_use_residual=False,
     moe_gate_k=2,
 )
@@ -150,8 +158,8 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=-1,
-    tensor=2,
+    zero1=dict(size=-1, fsdp=False),
+    tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,
 )
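zero1 changes here from a bare integer to a dict so that an fsdp flag can ride along, which is why the sanity check below now reads parallel.zero1.size instead of parallel.zero1. A sketch of a backward-compatible normalization, assuming plain dicts rather than InternLM's Config type:

def normalize_zero1(zero1):
    # Illustrative only: coerce the legacy scalar form (zero1=-1) into the
    # new dict form so code reading zero1["size"] works with old configs.
    if isinstance(zero1, int):
        return dict(size=zero1, fsdp=False)
    return zero1

assert normalize_zero1(-1) == dict(size=-1, fsdp=False)
assert normalize_zero1(dict(size=8, fsdp=True))["fsdp"] is True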


@@ -349,7 +349,7 @@ def args_sanity_check():
         assert (
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
-        assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
+        assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=-1 can fix this"


 def launch(
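This hunk (apparently internlm/initialize/launch.py) follows the config schema change: the old assert compared the whole zero1 dict against -1 and would now always fail. Rendered as standalone code under assumed plain-dict config shapes (not the actual launch.py), the MoE constraints amount to:

def check_moe_constraints(parallel: dict, overlap_sync_grad: bool, overlap_sync_param: bool) -> None:
    # Communication overlap is not supported together with MoE, which is
    # why the example config above turns both overlap flags off.
    assert not (overlap_sync_grad and overlap_sync_param), "not support overlap and moe at the same time"
    # With the new schema, the ZeRO-1 setting lives at zero1["size"].
    assert parallel["zero1"]["size"] == -1, "moe only support zero1, set zero1=-1 can fix this"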


@@ -424,7 +424,9 @@ class SimpleMemoryProfiler:
             layer_name, output.element_size() * output.nelement(), flush=False
         )

-    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _activation_trace_hook_forward(
+        self, chunk_id: int, model: Any, inputs: Any, output: Any  # pylint: disable=W0613
+    ) -> None:
         """
         Hook function to trace the activation memory usage for a forward pass.
@@ -437,7 +439,6 @@
             None
         """
         del model, inputs
-        assert isinstance(output, torch.Tensor), f"invalid output type: {type(output)}"

         if self._stoped:
             return
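Both profiler edits serve the same purpose: with MoE layers in the model, a forward hook's output is no longer guaranteed to be a single tensor (it can be a tuple that also carries, e.g., auxiliary gate losses), so the hard isinstance assert had to go and the annotation widens to Any. A sketch of how a hook could size such outputs (illustrative helper, not part of SimpleMemoryProfiler):

import torch
from typing import Any

def output_nbytes(output: Any) -> int:
    # Recursively sum the byte sizes of every tensor in a forward output,
    # tolerating the tuple/list outputs that MoE blocks can produce.
    if isinstance(output, torch.Tensor):
        return output.element_size() * output.nelement()
    if isinstance(output, (tuple, list)):
        return sum(output_nbytes(item) for item in output)
    return 0  # non-tensor leaves (None, scalars, ...) contribute nothing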