mirror of https://github.com/InternLM/InternLM
fix(moe): fix moe compatibility for fsdp and memory profiling (#417)
* fix moe compatibility for fsdp and memory profiling
* update moe config

pull/418/head
parent 37e0c86e5a
commit eeef07934a

@@ -4,7 +4,7 @@ DO_ALERT = False
 SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
-MLP_RATIO = 8 / 3
+MLP_RATIO = 4 / 3
 NUM_LAYER = 32
 VOCAB_SIZE = 103168
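
The halved MLP_RATIO changes the implied feed-forward width. A quick worked example (illustrative only, ignoring any rounding to hardware-friendly multiples the model code may apply):

HIDDEN_SIZE, MLP_RATIO = 4096, 4 / 3
ffn_hidden = int(HIDDEN_SIZE * MLP_RATIO)  # 5461, vs. int(4096 * 8 / 3) = 10922 before this change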

@@ -30,6 +30,14 @@ ckpt = dict(
     # 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
     # 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported.
     load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
+    # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
+    # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
+    # with an automatic restart mechanism upon training reboot.
+    # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
+    # path specified in `load_ckpt_info` by default.
+    # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
+    # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
+    auto_resume=True,
     checkpoint_every=CHECKPOINT_EVERY,
     async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
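
The added comments describe how `auto_resume` interacts with `load_ckpt_info`. As a minimal sketch (hypothetical path, not part of this commit), warm-starting from another model's weights would look like:

ckpt = dict(
    # auto_resume defaults to True and would load the latest ckpt from
    # save_ckpt_folder instead, so it must be disabled for a warm start.
    auto_resume=False,
    load_ckpt_info=dict(path="local:/mnt/ckpts/base_model/", content=("model",), ckpt_type="internlm"),
)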

@@ -43,7 +51,7 @@ data = dict(
     # micro_num means the number of micro_batch contained in one gradient update
     micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=1,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
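
Following the `packed_length = micro_bsz * SEQ_LEN` comment above, the bump to micro_bsz=2 doubles the tokens per packed micro-batch:

micro_bsz, SEQ_LEN = 2, 2048
packed_length = micro_bsz * SEQ_LEN  # 4096 tokens, up from 2048 with micro_bsz=1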

@@ -81,8 +89,8 @@ grad_scaler = dict(
 
 hybrid_zero_optimizer = dict(
     # Enable low_level_optimizer overlap_communication
-    overlap_sync_grad=True,
-    overlap_sync_param=True,
+    overlap_sync_grad=False,
+    overlap_sync_param=False,
     # bucket size for nccl communication params
     reduce_bucket_size=512 * 1024 * 1024,
     # grad clipping

@@ -133,7 +141,7 @@ model = dict(
     layer_norm_epsilon=1e-5,
     use_flash_attn=True,
     num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
-    num_experts=4,
+    num_experts=8,
     moe_use_residual=False,
     moe_gate_k=2,
 )
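
For context on `num_experts` and `moe_gate_k`: with top-k gating, each token is routed to its k highest-scoring experts. A minimal PyTorch sketch of that routing decision (an illustration of the general technique, not the repo's gate implementation):

import torch

def top_k_gate(hidden: torch.Tensor, w_gate: torch.Tensor, k: int = 2):
    """hidden: (tokens, dim); w_gate: (dim, num_experts)."""
    probs = torch.softmax(hidden @ w_gate, dim=-1)  # (tokens, num_experts)
    topk_probs, topk_idx = probs.topk(k, dim=-1)    # each token's k chosen experts
    return topk_probs, topk_idx

probs, idx = top_k_gate(torch.randn(4, 4096), torch.randn(4096, 8), k=2)
assert idx.shape == (4, 2)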

@@ -150,8 +158,8 @@ pipeline parallel (dict):
 tensor parallel: tensor parallel size, usually the number of GPUs per node.
 """
 parallel = dict(
-    zero1=-1,
-    tensor=2,
+    zero1=dict(size=-1, fsdp=False),
+    tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,
 )
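
Here `zero1` changes from a bare int to a dict so an FSDP flag can ride alongside the group size. A sketch of how a consumer might resolve the new schema (hypothetical helper, not from the repo):

def resolve_zero1_size(zero1_cfg: dict, data_parallel_size: int) -> int:
    """size == -1 means the zero1 group spans the whole data-parallel group."""
    size = zero1_cfg.get("size", -1)
    return data_parallel_size if size == -1 else size

assert resolve_zero1_size(dict(size=-1, fsdp=False), data_parallel_size=8) == 8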

@@ -349,7 +349,7 @@ def args_sanity_check():
         assert (
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
-        assert gpc.config.parallel.zero1 == -1, "moe only support zero1, set zero1=-1 can fix this"
+        assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=-1 can fix this"
 
 
 def launch(

@@ -424,7 +424,9 @@ class SimpleMemoryProfiler:
                 layer_name, output.element_size() * output.nelement(), flush=False
             )
 
-    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
+    def _activation_trace_hook_forward(
+        self, chunk_id: int, model: Any, inputs: Any, output: Any # pylint: disable=W0613
+    ) -> None:
         """
         Hook function to trace the activation memory usage for a forward pass.

@@ -437,7 +439,6 @@ class SimpleMemoryProfiler:
             None
         """
         del model, inputs
-        assert isinstance(output, torch.Tensor), f"invalid output type: {type(output)}"
 
         if self._stoped:
             return
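
The dropped assert reflects that the hook's `output` is no longer guaranteed to be a bare tensor; with MoE a block's forward may return, e.g., hidden states together with auxiliary losses. A tolerant size-accounting helper in that spirit (an assumption about MoE output shapes, not code from this commit):

import torch

def activation_bytes(output) -> int:
    # Recursively count bytes for a tensor, or a tuple/list of tensors; anything else counts 0.
    if isinstance(output, torch.Tensor):
        return output.element_size() * output.nelement()
    if isinstance(output, (tuple, list)):
        return sum(activation_bytes(o) for o in output)
    return 0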