mirror of https://github.com/InternLM/InternLM
1. fix(config): rampup_batch_size default value BC. (#515)
2. fix(config): standardize config parameter access.
3. feat(launch): add warmup_process_group
4. feat(memory): add cuda_memory_analyze

pull/523/head
parent 06e8301861
commit 757e19e01a
@@ -44,8 +44,8 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
 
-TRAIN_FOLDER = "/path/to/dataset"
-VALID_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = None  # "/path/to/dataset"
+VALID_FOLDER = None  # "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
@@ -64,12 +64,12 @@ data = dict(
     # each increment. For example, "192 24 8" means that the batch size (micro_num)
     # starts at 192 and increases by 24 every 8 steps. Defaults to None.
     # (IMPORTANT): The interval step size is 'micro_bsz'.
-    rampup_batch_size=None,
+    rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
     min_length=50,
-    # train_folder=TRAIN_FOLDER,
-    # valid_folder=VALID_FOLDER,
-    empty_cache_and_diag_interval=10,
+    train_folder=TRAIN_FOLDER,
+    valid_folder=VALID_FOLDER,
+    empty_cache_and_diag_interval=200,
+    diag_outlier_ratio=1.1,
 )
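The comments above define the ramp-up format. As a rough illustration, a toy helper (hypothetical; not the repository's batch-sampler code, and it ignores the note that the interval step size is scaled by 'micro_bsz') that expands a "start increment interval" string into (step, micro_num) milestones:

def expand_rampup(rampup_batch_size: str, target_micro_num: int):
    # "192 24 8": start at micro_num 192, add 24 every 8 steps until the
    # configured micro_num is reached.
    start, increment, interval = map(int, rampup_batch_size.split())
    milestones, micro_num, step = [], start, 0
    while micro_num < target_micro_num:
        milestones.append((step, micro_num))
        micro_num += increment
        step += interval
    milestones.append((step, target_micro_num))
    return milestones

print(expand_rampup("192 24 8", 240))  # [(0, 192), (8, 216), (16, 240)]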
@@ -35,19 +35,19 @@ def get_tensor_shape():
         if gpc.config.parallel.sequence_parallel:
             sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
             tensor_shape = (
-                gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
-                gpc.config.HIDDEN_SIZE,
+                gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"] // sequence_world_size,
+                gpc.config.model["hidden_size"],
             )
         else:
             tensor_shape = (
-                gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-                gpc.config.HIDDEN_SIZE,
+                gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"],
+                gpc.config.model["hidden_size"],
             )
     else:
         tensor_shape = (
             gpc.config.data["micro_bsz"],
-            gpc.config.SEQ_LEN,
-            gpc.config.HIDDEN_SIZE,
+            gpc.config.data["seq_len"],
+            gpc.config.model["hidden_size"],
         )
     return tensor_shape
 else:
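A quick worked example of the three shapes get_tensor_shape() can return, using assumed values (seq_len=2048, micro_bsz=2, hidden_size=4096, tensor-parallel world size 8; all numbers hypothetical):

seq_len, micro_bsz, hidden_size, tp_world_size = 2048, 2, 4096, 8

# sequence parallel: batch and sequence dims fold together and shard across TP ranks
shape_sp = (seq_len * micro_bsz // tp_world_size, hidden_size)  # (512, 4096)

# folded but unsharded (the inner else branch)
shape_folded = (seq_len * micro_bsz, hidden_size)  # (4096, 4096)

# the outer else branch keeps the batch dimension explicit
shape_plain = (micro_bsz, seq_len, hidden_size)  # (2, 2048, 4096)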
@@ -13,6 +13,7 @@ from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
+from internlm.utils.gputest import warmup_process_group
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
@@ -60,6 +61,9 @@ def get_default_parser():
 def args_sanity_check():
     assert gpc.config is not None, "config is not load!"
 
+    if "JOB_NAME" not in gpc.config:
+        gpc.config._add_item("JOB_NAME", "AnonymousJob")
+
     # the default model type is INTERNLM
     if "model_type" not in gpc.config:
         gpc.config._add_item("model_type", "INTERNLM")
@@ -144,10 +148,6 @@ def args_sanity_check():
     if "diag_outlier_ratio" not in data:
         data._add_item("diag_outlier_ratio", 1.1)
 
-    if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0:
-        bsz = data.micro_num
-        data._add_item("rampup_batch_size", f"{bsz} {bsz} 1")
-
     data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
 
     if gpc.is_rank_for_log():
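For reference, the removed block above performed the backward-compatible normalization that motivates item 1 of the commit message: a missing key, the old default None, and the new default "" all collapse to a constant micro_num schedule. A standalone sketch of that behavior (hypothetical function, not launch.py code):

def normalize_rampup(data_cfg: dict) -> str:
    rampup = data_cfg.get("rampup_batch_size")
    if not rampup:  # covers a missing key, None, and ""
        bsz = data_cfg["micro_num"]
        rampup = f"{bsz} {bsz} 1"
    return rampup

assert normalize_rampup({"micro_num": 4}) == "4 4 1"
assert normalize_rampup({"micro_num": 4, "rampup_batch_size": ""}) == "4 4 1"
assert normalize_rampup({"micro_num": 4, "rampup_batch_size": "192 24 8"}) == "192 24 8"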
@@ -423,6 +423,8 @@ def launch(
 
     gpc.set_seed(seed)
 
+    warmup_process_group()
+
     if gpc.is_rank_for_log():
         logger.info(
             f"Distributed environment is initialized, "
@@ -101,7 +101,7 @@ def evaluate_on_val_dls(
     assert total_val_bsz % data_cfg.micro_bsz == 0
     num_microbatches = total_val_bsz // data_cfg.micro_bsz
     tensor_shape = torch.Size(
-        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
     )
 
     with switch_evaluation_pipeline_scheduler(
@@ -27,10 +27,17 @@ from internlm.utils.common import get_current_device
 logger = get_logger(__file__)
 
 
+# Global cuda cache flush counter
+n_caching_allocator_flushes = 0
+
+
 def empty_cache_and_diag(batch_count, interval=50):
     """empty cuda cache and run diag bench or tests."""
     if interval <= 0:
         interval = 50
+
+    cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
+
     if batch_count % int(interval) == 0:
         # there is no need to do diag on the first batch
         if batch_count > 0:
@@ -259,3 +266,75 @@ def bench_gpu(use_flash_attn=True):
             address=gpc.config.monitor.alert.feishu_alert_address,
             message=msg,
         )
+
+
+"""
+Useful utility functions migrated from deepspeed.
+"""
+
+
+def warmup_process_group():
+    # Prevent OOM from nccl communication.
+    if dist.is_initialized():
+        buffer = torch.ones([64]).cuda()
+        if gpc.is_initialized(ParallelMode.DATA):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.DATA))
+        if gpc.is_initialized(ParallelMode.TENSOR):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.TENSOR))
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.PIPELINE))
+        if gpc.is_initialized(ParallelMode.ZERO1):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO1))
+        if gpc.is_initialized(ParallelMode.MODEL):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.MODEL))
+        if gpc.is_initialized(ParallelMode.ZERO3_DP):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO3_DP))
+        if gpc.is_initialized(ParallelMode.EXPERT_DATA):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT_DATA))
+        if gpc.is_initialized(ParallelMode.EXPERT):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT))
+
+        dist.barrier()
+        del buffer
+        torch.cuda.empty_cache()
+
+
+def cuda_memory_analyze(step=0, print_mm_suage=False):
+    global n_caching_allocator_flushes
+    torch.cuda.synchronize()
+
+    g_rank = gpc.get_global_rank()
+    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
+    rank_id = f"Rank:{g_rank}-tp{tp_rank}-pp{pp_rank}-dp{dp_rank}"
+
+    if print_mm_suage and gpc.get_local_rank(ParallelMode.DATA) == 0:
+        logger.info(
+            f"{rank_id}: Step {step}: "
+            f"Allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Max_Allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Reserved {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Max_Reserved {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 4)} GB "
+        )
+
+    torch.cuda.reset_peak_memory_stats()
+
+    # warn user about caching allocator flushes
+    memory_stats = torch.cuda.memory_stats()
+    alloc_retries = memory_stats.get("num_alloc_retries")
+    if alloc_retries is None:
+        alloc_retries = 0
+    if alloc_retries > n_caching_allocator_flushes:
+        retry_count = alloc_retries - n_caching_allocator_flushes
+        if gpc.get_global_rank() == 0:
+            logger.warning(
+                f"{rank_id}: pytorch allocator cache flushes {retry_count} times since last step. "
+                "This happens when there is high memory pressure and is detrimental to "
+                "performance. If this is happening frequently consider adjusting "
+                "settings to reduce memory consumption. If you are unable to "
+                "make the cache flushes go away consider adding "
+                "torch.cuda.empty_cache() calls in your training loop to ensure "
+                "that all ranks flush their caches at the same time."
+            )
+    n_caching_allocator_flushes = alloc_retries
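To show how the new utilities fit together, a minimal training-loop sketch (hypothetical; train_loader and the surrounding setup are assumed, not part of this diff). empty_cache_and_diag drives cuda_memory_analyze on the configured interval, as wired up in the gputest.py hunks above:

from internlm.core.context import global_context as gpc
from internlm.utils.gputest import empty_cache_and_diag

for batch_count, batch in enumerate(train_loader):  # train_loader is assumed
    # logs per-rank allocated/reserved memory on the interval (and for the
    # first few batches), and warns when the allocator flushed its cache
    empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
    ...  # forward / backward / optimizer step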
@@ -106,13 +106,13 @@ def main(args):
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
-        seq_len=gpc.config.SEQ_LEN,
+        seq_len=gpc.config.data["seq_len"],
         hidden_size=gpc.config.model.hidden_size,
         num_layers=gpc.config.model.num_layers,
         vocab_size=gpc.config.model.vocab_size,
         global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-        mlp_ratio=gpc.config.MLP_RATIO,
+        mlp_ratio=gpc.config.model["mlp_ratio"],
     )
 
     # get and broadcast current time
train.py
@@ -77,13 +77,13 @@ def main(args):
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
-        seq_len=gpc.config.SEQ_LEN,
+        seq_len=gpc.config.data["seq_len"],
         hidden_size=gpc.config.model.hidden_size,
         num_layers=gpc.config.model.num_layers,
         vocab_size=gpc.config.model.vocab_size,
         global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-        mlp_ratio=gpc.config.MLP_RATIO,
+        mlp_ratio=gpc.config.model["mlp_ratio"],
     )
 
     # get and broadcast current time
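Taken together, these last hunks show the access pattern that item 2 of the commit message standardizes: hyperparameters are read from the data and model sub-configs instead of top-level constants. A compact before/after sketch:

seq_len = gpc.config.data["seq_len"]           # was: gpc.config.SEQ_LEN
hidden_size = gpc.config.model["hidden_size"]  # was: gpc.config.HIDDEN_SIZE
mlp_ratio = gpc.config.model["mlp_ratio"]      # was: gpc.config.MLP_RATIO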