feat(train.py): support torch profiler (#201)

* feat(train.py): support torch profiling

* feat(train.py): optimize initialize_llm_profile

* feat(train.py): profiling with tp0 and dp0

* move sequence parallel context manager to evaluation func

* fix lint

* move type_ids processing to load_new_batch

* fix lint

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Author: huangting4201 (committed via GitHub)
Date: 2023-08-21 15:23:38 +08:00
Parent: 4832671abe
Commit: 53648dc0e9
3 changed files with 230 additions and 181 deletions
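
For orientation before the diffs: the commit wires train.py to the standard torch.profiler pattern, in which the profiler is entered as a context manager, prof.step() is called once per iteration, and a schedule decides which iterations are actually traced. A minimal, self-contained sketch of that pattern (the model, sizes, and ./traces path are illustrative, not taken from this commit):

    import torch

    model = torch.nn.Linear(64, 64)
    inputs = torch.randn(8, 64)

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./traces"),
        with_stack=True,
    ) as prof:
        for _ in range(6):
            model(inputs).sum().backward()  # stand-in for one training step
            prof.step()                     # advance the profiling schedule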

internlm/utils/common.py

@@ -218,3 +218,21 @@ def get_megatron_flops(
     tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
     return tflops
+
+
+class DummyProfile:
+    """
+    Dummy Profile.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, a, b, c):
+        pass
+
+    def step(self):
+        pass
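
Why the dummy class: ranks other than (dp rank 0, tp rank 0) execute the same `with ... as prof:` block in train.py, so they need an object with the same surface as torch.profiler.profile. DummyProfile accepts and ignores any constructor arguments, and its __enter__/__exit__/step do nothing:

    # Any arguments are swallowed; entering, exiting, and stepping are no-ops.
    with DummyProfile(activities=None, schedule=None) as prof:
        for _ in range(3):
            prof.step()  # keeps the loop body identical on non-profiled ranks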

internlm/utils/evaluation.py

@@ -50,6 +50,16 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
         trainer.schedule._hooks = prev_metric_hooks
 
 
+@contextmanager
+def switch_sequence_parallel_mode():
+    prev_mode = gpc.config.model.sequence_parallel
+    try:
+        gpc.config.model.sequence_parallel = False
+        yield
+    finally:
+        gpc.config.model.sequence_parallel = prev_mode
+
+
 def evaluate_on_val_dls(
     trainer,
     val_dls,
@@ -58,6 +68,7 @@ def evaluate_on_val_dls(
     step_count,
     update_panel: bool = False,
 ):
-    torch.cuda.empty_cache()
-    trainer.eval()
-    verbose = gpc.is_rank_for_log()
+    with switch_sequence_parallel_mode():
+        torch.cuda.empty_cache()
+        trainer.eval()
+        verbose = gpc.is_rank_for_log()
@@ -154,13 +165,3 @@ def evaluate_on_val_dls(
-    trainer.train()
-    torch.cuda.empty_cache()
-    dist.barrier()
-
-
-@contextmanager
-def switch_sequence_parallel_mode():
-    prev_mode = gpc.config.model.sequence_parallel
-    try:
-        gpc.config.model.sequence_parallel = False
-        yield
-    finally:
-        gpc.config.model.sequence_parallel = prev_mode
+        trainer.train()
+        torch.cuda.empty_cache()
+        dist.barrier()
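
The try/finally shape of switch_sequence_parallel_mode is the standard temporary-override idiom: even if evaluation raises, sequence_parallel is restored. A generic sketch of the same pattern on a plain dict (the config object here is illustrative, not InternLM's gpc):

    from contextlib import contextmanager

    @contextmanager
    def override_flag(cfg: dict, key: str, value):
        # Temporarily set cfg[key] = value; always restore the old value on exit.
        prev = cfg[key]
        cfg[key] = value
        try:
            yield
        finally:
            cfg[key] = prev

    cfg = {"sequence_parallel": True}
    with override_flag(cfg, "sequence_parallel", False):
        assert cfg["sequence_parallel"] is False  # evaluation would run here
    assert cfg["sequence_parallel"] is True       # restored afterwards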

train.py

@@ -38,12 +38,13 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import (
     BatchSkipper,
+    DummyProfile,
     get_master_node,
     get_megatron_flops,
     launch_time,
     parse_args,
 )
-from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
+from internlm.utils.evaluation import evaluate_on_val_dls
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (
@@ -292,6 +293,11 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
             train_state.num_consumed_samples_in_epoch = 0
         timer("batch-gen").stop()
 
+    if batch[0].get("type_ids", None) is not None:
+        # if use_flash_attn is False, we need to unpack type_ids
+        if not gpc.config.model.use_flash_attn:
+            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
+
     return batch, train_iter
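
For intuition about the moved type_ids handling: with packed (flash-attention style) batches, cu_seqlens stores the cumulative token offsets of the sequences packed into one flat stream, and unpacking splits that stream back into padded per-sequence rows. A toy illustration of the idea (not InternLM's actual unpack_data):

    import torch

    def toy_unpack(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # packed: (total_tokens,) flat stream; cu_seqlens: (num_seqs + 1,) offsets.
        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        out = packed.new_zeros(len(lengths), int(lengths.max()))
        for i, (start, end) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            out[i, : end - start] = packed[start:end]
        return out

    packed = torch.tensor([1, 1, 1, 2, 2])
    cu_seqlens = torch.tensor([0, 3, 5])
    print(toy_unpack(packed, cu_seqlens))  # [[1, 1, 1], [2, 2, 0]]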
@@ -323,6 +329,29 @@ def initialize_optimizer(model: nn.Module):
     return optimizer, beta2_scheduler, lr_scheduler
 
 
+def initialize_llm_profile(profiling: bool = False, start_time: str = None):
+    """Initialize and return the profiler context manager instance."""
+
+    if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+        llm_profile = torch.profiler.profile
+        logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
+    else:
+        llm_profile = DummyProfile
+
+    return llm_profile(
+        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+        schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(
+            f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
+            + f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
+        ),
+        with_stack=True,
+        with_modules=True,
+    )
 
 
 def record_current_batch_training_metrics(
     get_tflops_func,
     logger,
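
The chosen schedule traces exactly one iteration: steps 0-4 are skipped (skip_first=5), step 5 is the wait phase, step 6 warms the profiler up, and step 7 is the single active step that gets recorded; repeat=1 ends the cycle there. The schedule object is a plain callable, so this can be verified directly:

    import torch

    sched = torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1)
    for step in range(9):
        print(step, sched(step))
    # 0-4 -> ProfilerAction.NONE (skipped), 5 -> NONE (wait), 6 -> WARMUP,
    # 7 -> RECORD_AND_SAVE (the traced step), 8 -> NONE (repeat exhausted)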
@@ -587,6 +616,7 @@ def main(args):
     # transfer the train data loader into train data iterator
     train_iter = iter(train_dl)
 
-    # start iterating the train data and begin training
-    for batch_count in range(train_state.batch_count, total_steps):
-        if batch_count % 50 == 0:
+    with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
+        # start iterating the train data and begin training
+        for batch_count in range(train_state.batch_count, total_steps):
+            if batch_count % 50 == 0:
@@ -609,17 +639,15 @@ def main(args):
-        # zero the grads of parameters
-        trainer.zero_grad()
-        type_ids = batch[0].pop("type_ids", None)
-        # process data
-        # if use_flash_attn is False, we need to unpack type_ids
-        if not gpc.config.model.use_flash_attn:
-            type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
-        if type_ids is not None:
-            metric.set_current_type_ids(type_ids=type_ids)
-
-        # do forward and backward
-        timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
-        timer("fwd-bwd").stop()
+            # zero the grads of parameters
+            trainer.zero_grad()
+
+            # process data
+            if batch[0].get("type_ids", None) is not None:
+                metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
+
+            # do forward and backward
+            timer("fwd-bwd").start()
+            _, _, loss = trainer.execute_schedule(
+                batch, forward_only=False, return_loss=True, return_output_label=False
+            )
+            timer("fwd-bwd").stop()
@@ -634,7 +662,8 @@ def main(args):
-        if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
-            logger.warning(f"Warning: skip parameter update at step {batch_count}.")
-            send_alert_message(
-                address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
-            )
+            if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+                logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                send_alert_message(
+                    address=gpc.config.alert_address,
+                    message=f"Warning: skip parameter update at step {batch_count}.",
+                )
 
-        # calculate and record the training metrics, eg. loss, accuracy and so on.
+            # calculate and record the training metrics, eg. loss, accuracy and so on.
@@ -660,7 +689,6 @@ def main(args):
-        # evaluate on validation data loaders
-        if valid_every > 0 and train_state.step_count % valid_every == 0:
-            with switch_sequence_parallel_mode():
+            # evaluate on validation data loaders
+            if valid_every > 0 and train_state.step_count % valid_every == 0:
                 evaluate_on_val_dls(
                     trainer=trainer,
                     val_dls=val_dls,
@@ -670,13 +698,15 @@ def main(args):
                     update_panel=uniscale_logger is not None,
                 )
 
-        if memory_profiler is not None:
-            memory_profiler.step()
-
-        # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
-        # # save batch sampler that tracks the true consumed samples
-        ckpt_save_manager.try_save_checkpoint(train_state)
+            # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
+            # # save batch sampler that tracks the true consumed samples
+            ckpt_save_manager.try_save_checkpoint(train_state)
+
+            if memory_profiler is not None:
+                memory_profiler.step()
+
+            prof.step()
 
     ckpt_save_manager.wait_async_upload_finish()
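
A usage note, assuming the stock handler behavior: tensorboard_trace_handler writes one trace file per profiled rank under {JOB_NAME}/{start_time}/traces/rank..._dp..._tp..._pp..., and those traces are inspected with TensorBoard through the torch-tb-profiler plugin (pip install torch-tb-profiler, then point tensorboard --logdir at the trace directory).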