mirror of https://github.com/InternLM/InternLM
feat(train.py): support torch profiler (#201)
* feat(train.py): support torch profiling * feat(train.py): optimize initialize_llm_profile * feat(train.py): profiling with tp0 and dp0 * move sequence parallel context manager to evalation func * fix lint * move the process for type_ids to load_new_batch * fix lint --------- Co-authored-by: yingtongxiong <974106207@qq.com>pull/216/head^2
parent
4832671abe
commit
53648dc0e9
|
@ -218,3 +218,21 @@ def get_megatron_flops(
|
|||
|
||||
tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
|
||||
return tflops
|
||||
|
||||
|
||||
class DummyProfile:
|
||||
"""
|
||||
Dummy Profile.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, a, b, c):
|
||||
pass
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
|
|
@ -50,6 +50,16 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
|
|||
trainer.schedule._hooks = prev_metric_hooks
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_sequence_parallel_mode():
|
||||
prev_mode = gpc.config.model.sequence_parallel
|
||||
try:
|
||||
gpc.config.model.sequence_parallel = False
|
||||
yield
|
||||
finally:
|
||||
gpc.config.model.sequence_parallel = prev_mode
|
||||
|
||||
|
||||
def evaluate_on_val_dls(
|
||||
trainer,
|
||||
val_dls,
|
||||
|
@ -58,6 +68,7 @@ def evaluate_on_val_dls(
|
|||
step_count,
|
||||
update_panel: bool = False,
|
||||
):
|
||||
with switch_sequence_parallel_mode():
|
||||
torch.cuda.empty_cache()
|
||||
trainer.eval()
|
||||
verbose = gpc.is_rank_for_log()
|
||||
|
@ -154,13 +165,3 @@ def evaluate_on_val_dls(
|
|||
trainer.train()
|
||||
torch.cuda.empty_cache()
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def switch_sequence_parallel_mode():
|
||||
prev_mode = gpc.config.model.sequence_parallel
|
||||
try:
|
||||
gpc.config.model.sequence_parallel = False
|
||||
yield
|
||||
finally:
|
||||
gpc.config.model.sequence_parallel = prev_mode
|
||||
|
|
56
train.py
56
train.py
|
@ -38,12 +38,13 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
|||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
DummyProfile,
|
||||
get_master_node,
|
||||
get_megatron_flops,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
|
||||
from internlm.utils.evaluation import evaluate_on_val_dls
|
||||
from internlm.utils.logger import get_logger, initialize_uniscale_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.model_checkpoint import (
|
||||
|
@ -292,6 +293,11 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
|
|||
train_state.num_consumed_samples_in_epoch = 0
|
||||
timer("batch-gen").stop()
|
||||
|
||||
if batch[0].get("type_ids", None) is not None:
|
||||
# if use_flash_attn is False, we need to unpack type_ids
|
||||
if not gpc.config.model.use_flash_attn:
|
||||
batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
|
||||
|
||||
return batch, train_iter
|
||||
|
||||
|
||||
|
@ -323,6 +329,29 @@ def initialize_optimizer(model: nn.Module):
|
|||
return optimizer, beta2_scheduler, lr_scheduler
|
||||
|
||||
|
||||
def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
||||
"""Initialize and return the profiler context manager instance."""
|
||||
|
||||
if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
llm_profile = torch.profiler.profile
|
||||
logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
|
||||
else:
|
||||
llm_profile = DummyProfile
|
||||
|
||||
return llm_profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
|
||||
schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
|
||||
+ f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
|
||||
+ f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
|
||||
+ f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
|
||||
),
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
)
|
||||
|
||||
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
|
@ -587,6 +616,7 @@ def main(args):
|
|||
# transfer the train data loader into train data iterator
|
||||
train_iter = iter(train_dl)
|
||||
|
||||
with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
|
||||
# start iterating the train data and begin training
|
||||
for batch_count in range(train_state.batch_count, total_steps):
|
||||
if batch_count % 50 == 0:
|
||||
|
@ -609,17 +639,15 @@ def main(args):
|
|||
|
||||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
type_ids = batch[0].pop("type_ids", None)
|
||||
# process data
|
||||
# if use_flash_attn is False, we need to unpack type_ids
|
||||
if not gpc.config.model.use_flash_attn:
|
||||
type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
|
||||
if type_ids is not None:
|
||||
metric.set_current_type_ids(type_ids=type_ids)
|
||||
if batch[0].get("type_ids", None) is not None:
|
||||
metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
|
||||
|
||||
# do forward and backward
|
||||
timer("fwd-bwd").start()
|
||||
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
|
||||
_, _, loss = trainer.execute_schedule(
|
||||
batch, forward_only=False, return_loss=True, return_output_label=False
|
||||
)
|
||||
timer("fwd-bwd").stop()
|
||||
|
||||
# update parameters, and returns (success_update, grad_norm)
|
||||
|
@ -634,7 +662,8 @@ def main(args):
|
|||
if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
|
||||
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
|
||||
address=gpc.config.alert_address,
|
||||
message=f"Warning: skip parameter update at step {batch_count}.",
|
||||
)
|
||||
|
||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
|
@ -660,7 +689,6 @@ def main(args):
|
|||
|
||||
# evaluate on validation data loaders
|
||||
if valid_every > 0 and train_state.step_count % valid_every == 0:
|
||||
with switch_sequence_parallel_mode():
|
||||
evaluate_on_val_dls(
|
||||
trainer=trainer,
|
||||
val_dls=val_dls,
|
||||
|
@ -670,13 +698,15 @@ def main(args):
|
|||
update_panel=uniscale_logger is not None,
|
||||
)
|
||||
|
||||
if memory_profiler is not None:
|
||||
memory_profiler.step()
|
||||
|
||||
# checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
|
||||
# # save batch sampler that tracks the true consumed samples
|
||||
ckpt_save_manager.try_save_checkpoint(train_state)
|
||||
|
||||
if memory_profiler is not None:
|
||||
memory_profiler.step()
|
||||
|
||||
prof.step()
|
||||
|
||||
ckpt_save_manager.wait_async_upload_finish()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue