1. fix(config): rampup_batch_size default value BC. (#515)

2. fix(config): standardize config parameter access.
3. feat(launch): add warmup_process_group
4. feat(memory): add cuda_memory_analyze
pull/523/head
Guoteng 2023-11-28 19:33:46 +08:00 committed by GitHub
parent 06e8301861
commit 757e19e01a
7 changed files with 102 additions and 21 deletions

View File

@@ -44,8 +44,8 @@ ckpt = dict(
     oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
 )
 
-TRAIN_FOLDER = "/path/to/dataset"
-VALID_FOLDER = "/path/to/dataset"
+TRAIN_FOLDER = None  # "/path/to/dataset"
+VALID_FOLDER = None  # "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
@@ -64,12 +64,12 @@ data = dict(
     # each increment. For example, "192 24 8" means that the batch size (micro_num)
     # starts at 192 and increases by 24 every 8 steps. Defaults to None.
     # (IMPORTANT): The interval step size is 'micro_bsz'.
-    rampup_batch_size=None,
+    rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
    min_length=50,
-    # train_folder=TRAIN_FOLDER,
-    # valid_folder=VALID_FOLDER,
-    empty_cache_and_diag_interval=10,
+    train_folder=TRAIN_FOLDER,
+    valid_folder=VALID_FOLDER,
+    empty_cache_and_diag_interval=200,
     diag_outlier_ratio=1.1,
 )

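Editor's note: the config comment above describes rampup_batch_size as a "start increment interval" string, and this commit makes the empty default backward compatible with the old None. The snippet below is an editorial sketch of one plausible reading of that schedule; the real expansion logic lives outside this diff, and micro_num_at is a made-up helper name used only for illustration.

# Editorial sketch: one plausible interpretation of the "start increment interval"
# ramp-up string and of the backward-compatible falsy defaults (None or "").
def micro_num_at(step, rampup_batch_size, micro_num):
    if not rampup_batch_size:  # None or "" -> no ramp-up, keep the configured micro_num
        return micro_num
    start, increment, interval = map(int, rampup_batch_size.split())
    return min(micro_num, start + (step // interval) * increment)

print(micro_num_at(0, "192 24 8", 256))   # 192
print(micro_num_at(16, "192 24 8", 256))  # 240
print(micro_num_at(123, "", 4))           # 4: "" behaves like the old None default

With either "" or None the configured micro_num is used unchanged, which is the backward-compatibility point of item 1 in the commit message.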
View File

@@ -35,19 +35,19 @@ def get_tensor_shape():
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
-                    gpc.config.HIDDEN_SIZE,
+                    gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.model["hidden_size"],
                 )
             else:
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-                    gpc.config.HIDDEN_SIZE,
+                    gpc.config.data["seq_len"] * gpc.config.data["micro_bsz"],
+                    gpc.config.model["hidden_size"],
                 )
         else:
             tensor_shape = (
                 gpc.config.data["micro_bsz"],
-                gpc.config.SEQ_LEN,
-                gpc.config.HIDDEN_SIZE,
+                gpc.config.data["seq_len"],
+                gpc.config.model["hidden_size"],
             )
         return tensor_shape
     else:

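Editor's note: the hunk above (and the evaluation and training hunks below) implement item 2 of the commit message by reading seq_len, hidden_size, and mlp_ratio from their owning config sections instead of duplicated top-level constants. The snippet below is a minimal editorial stand-in for InternLM's dict-backed config object, meant only to illustrate the access pattern; it is not the project's actual Config class.

# Editorial stand-in: a dict whose attribute access falls back to item access,
# so nested sections can be read either way, as gpc.config allows.
class SimpleConfig(dict):
    __getattr__ = dict.__getitem__

cfg = SimpleConfig(
    data=SimpleConfig(seq_len=2048, micro_bsz=2),
    model=SimpleConfig(hidden_size=4096, mlp_ratio=8 / 3),
)

# Standardized access reads each value from the section that owns it ...
seq_len = cfg.data["seq_len"]            # replaces the old cfg.SEQ_LEN constant
hidden_size = cfg.model["hidden_size"]   # replaces the old cfg.HIDDEN_SIZE constant
# ... so duplicated top-level SEQ_LEN / HIDDEN_SIZE / MLP_RATIO entries are no longer needed.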
View File

@@ -13,6 +13,7 @@ from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
+from internlm.utils.gputest import warmup_process_group
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout
@@ -60,6 +61,9 @@ def get_default_parser():
 def args_sanity_check():
     assert gpc.config is not None, "config is not load!"
 
+    if "JOB_NAME" not in gpc.config:
+        gpc.config._add_item("JOB_NAME", "AnonymousJob")
+
     # the default model type is INTERNLM
     if "model_type" not in gpc.config:
         gpc.config._add_item("model_type", "INTERNLM")
@@ -144,10 +148,6 @@ def args_sanity_check():
     if "diag_outlier_ratio" not in data:
         data._add_item("diag_outlier_ratio", 1.1)
 
-    if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0:
-        bsz = data.micro_num
-        data._add_item("rampup_batch_size", f"{bsz} {bsz} 1")
-
     data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
 
     if gpc.is_rank_for_log():
@@ -423,6 +423,8 @@ def launch(
     gpc.set_seed(seed)
 
+    warmup_process_group()
+
     if gpc.is_rank_for_log():
         logger.info(
             f"Distributed environment is initialized, "
View File

@@ -101,7 +101,7 @@ def evaluate_on_val_dls(
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 num_microbatches = total_val_bsz // data_cfg.micro_bsz
                 tensor_shape = torch.Size(
-                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                    [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.model["hidden_size"]]
                 )
 
                 with switch_evaluation_pipeline_scheduler(

View File

@ -27,10 +27,17 @@ from internlm.utils.common import get_current_device
logger = get_logger(__file__)
# Gloabl cuda cache flush counter
n_caching_allocator_flushes = 0
def empty_cache_and_diag(batch_count, interval=50):
"""empty cuda cache and run diag bench or tests."""
if interval <= 0:
interval = 50
cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
if batch_count % int(interval) == 0:
# there is no need to do diag on the first batch
if batch_count > 0:
@@ -259,3 +266,75 @@ def bench_gpu(use_flash_attn=True):
             address=gpc.config.monitor.alert.feishu_alert_address,
             message=msg,
         )
+
+
+"""
+Useful utility functions migrated from DeepSpeed.
+"""
+
+
+def warmup_process_group():
+    # Prevent OOM from nccl communication.
+    if dist.is_initialized():
+        buffer = torch.ones([64]).cuda()
+        if gpc.is_initialized(ParallelMode.DATA):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.DATA))
+        if gpc.is_initialized(ParallelMode.TENSOR):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.TENSOR))
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.PIPELINE))
+        if gpc.is_initialized(ParallelMode.ZERO1):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO1))
+        if gpc.is_initialized(ParallelMode.MODEL):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.MODEL))
+        if gpc.is_initialized(ParallelMode.ZERO3_DP):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.ZERO3_DP))
+        if gpc.is_initialized(ParallelMode.EXPERT_DATA):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT_DATA))
+        if gpc.is_initialized(ParallelMode.EXPERT):
+            dist.all_reduce(buffer, group=gpc.get_group(ParallelMode.EXPERT))
+
+        dist.barrier()
+        del buffer
+        torch.cuda.empty_cache()
+
+
+def cuda_memory_analyze(step=0, print_mm_suage=False):
+    global n_caching_allocator_flushes
+
+    torch.cuda.synchronize()
+
+    g_rank = gpc.get_global_rank()
+    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    dp_rank = gpc.get_local_rank(ParallelMode.DATA)
+    rank_id = f"Rank:{g_rank}-tp{tp_rank}-pp{pp_rank}-dp{dp_rank}"
+
+    if print_mm_suage and gpc.get_local_rank(ParallelMode.DATA) == 0:
+        logger.info(
+            f"{rank_id}: Step {step}: "
+            f"Allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Max_Allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Reserved {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 4)} GB, "
+            f"Max_Reserved {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024), 4)} GB "
+        )
+
+    torch.cuda.reset_peak_memory_stats()
+
+    # warn user about caching allocator flushes
+    memory_stats = torch.cuda.memory_stats()
+    alloc_retries = memory_stats.get("num_alloc_retries")
+    if alloc_retries is None:
+        alloc_retries = 0
+    if alloc_retries > n_caching_allocator_flushes:
+        retry_count = alloc_retries - n_caching_allocator_flushes
+        if gpc.get_global_rank() == 0:
+            logger.warning(
+                f"{rank_id}: pytorch allocator cache flushes {retry_count} times since last step. "
+                "This happens when there is high memory pressure and is detrimental to "
+                "performance. If this is happening frequently, consider adjusting "
+                "settings to reduce memory consumption. If you are unable to "
+                "make the cache flushes go away, consider adding "
+                "torch.cuda.empty_cache() calls in your training loop to ensure "
+                "that all ranks flush their caches at the same time."
+            )
+        n_caching_allocator_flushes = alloc_retries

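Editor's note: a hedged usage sketch of the two new helpers. warmup_process_group() is already wired into launch() right after gpc.set_seed(seed) (see the launch.py hunk above), so user code normally only drives empty_cache_and_diag(), which in turn calls cuda_memory_analyze() on the configured interval. The total_steps and empty_cache_and_diag_interval fields are assumed to live in the usual data config section.

# Editorial sketch (not part of the commit), showing where the new helpers sit
# in a training loop; the loop body itself is elided.
from internlm.core.context import global_context as gpc
from internlm.utils.gputest import empty_cache_and_diag

for batch_count in range(gpc.config.data.total_steps):  # `total_steps` assumed from the data config
    empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
    ...  # forward / backward / optimizer step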
View File

@@ -106,13 +106,13 @@ def main(args):
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
-        seq_len=gpc.config.SEQ_LEN,
+        seq_len=gpc.config.data["seq_len"],
         hidden_size=gpc.config.model.hidden_size,
         num_layers=gpc.config.model.num_layers,
         vocab_size=gpc.config.model.vocab_size,
         global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-        mlp_ratio=gpc.config.MLP_RATIO,
+        mlp_ratio=gpc.config.model["mlp_ratio"],
     )
 
     # get and broadcast current time

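Editor's note: the TFLOPS figure produced through get_megatron_flops should correspond, up to implementation details not visible in this diff, to the standard Megatron-LM estimate (Narayanan et al., 2021). The editorial sketch below spells that formula out with the MLP term generalized by mlp_ratio; it is a reference approximation, not InternLM's exact implementation, and it ignores the gated-MLP third projection for simplicity.

# Editorial sketch of the Megatron-LM per-GPU TFLOPS estimate.
def megatron_tflops(elapsed_time_per_iter, checkpoint, seq_len, hidden_size, num_layers,
                    vocab_size, global_batch_size, global_world_size, mlp_ratio=4.0):
    B, s, h, l, V = global_batch_size, seq_len, hidden_size, num_layers, vocab_size
    # Forward FLOPs per layer: attention projections (8*B*s*h^2), attention scores and
    # context (4*B*s^2*h), and the two MLP projections (4*mlp_ratio*B*s*h^2).
    per_layer_fwd = 8 * B * s * h**2 + 4 * B * s**2 * h + 4 * mlp_ratio * B * s * h**2
    lm_head_fwd = 2 * B * s * h * V
    # Backward costs ~2x forward; activation checkpointing adds one extra forward pass.
    passes = 4 if checkpoint else 3
    total_flops = passes * l * per_layer_fwd + 3 * lm_head_fwd
    return total_flops / (elapsed_time_per_iter * global_world_size * 1e12)

With mlp_ratio=4 and checkpoint=True this reduces to 96·B·s·l·h²·(1 + s/(6h) + V/(16·l·h)), the closed form given in the Megatron-LM paper.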
View File

@@ -77,13 +77,13 @@ def main(args):
     get_tflops_func = partial(
         get_megatron_flops,
         checkpoint=gpc.config.model.checkpoint,
-        seq_len=gpc.config.SEQ_LEN,
+        seq_len=gpc.config.data["seq_len"],
         hidden_size=gpc.config.model.hidden_size,
         num_layers=gpc.config.model.num_layers,
         vocab_size=gpc.config.model.vocab_size,
         global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
         global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
-        mlp_ratio=gpc.config.MLP_RATIO,
+        mlp_ratio=gpc.config.model["mlp_ratio"],
     )
 
     # get and broadcast current time