mirror of https://github.com/InternLM/InternLM
feat(train.py): support torch profiler (#201)
* feat(train.py): support torch profiling
* feat(train.py): optimize initialize_llm_profile
* feat(train.py): profiling with tp0 and dp0
* move sequence parallel context manager to evaluation func
* fix lint
* move the process for type_ids to load_new_batch
* fix lint

Co-authored-by: yingtongxiong <974106207@qq.com>

parent 4832671abe
commit 53648dc0e9
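
At a glance, the commit picks a real torch.profiler context only on the rank that is both dp0 and tp0, and a no-op DummyProfile everywhere else, so every rank can drive the same `with ... as prof:` / `prof.step()` interface around the training loop. Below is a condensed, self-contained sketch of that pattern, not the exact code from this commit; `make_profiler`, `is_profiled_rank`, `train_step`, and the `"./traces"` output directory are hypothetical stand-ins for the real helpers and paths in train.py.

    import torch


    class DummyProfile:
        """No-op stand-in so non-profiled ranks share the with/step interface."""

        def __init__(self, *args, **kwargs) -> None:
            pass

        def __enter__(self):
            return self

        def __exit__(self, a, b, c):
            pass

        def step(self):
            pass


    def make_profiler(profiling: bool, is_profiled_rank: bool):
        # Only one rank per parallel group pays the profiling overhead; the rest get the dummy.
        ctx = torch.profiler.profile if (profiling and is_profiled_rank) else DummyProfile
        return ctx(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler("./traces"),
            with_stack=True,
            with_modules=True,
        )


    def train_step():
        pass  # hypothetical stand-in for the real forward/backward/optimizer step


    with make_profiler(profiling=True, is_profiled_rank=True) as prof:
        for _ in range(10):
            train_step()
            prof.step()  # advance the profiler schedule once per training iteration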

internlm/utils/common.py

@@ -218,3 +218,21 @@ def get_megatron_flops(
     tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
 
     return tflops
+
+
+class DummyProfile:
+    """
+    Dummy Profile.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        pass
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, a, b, c):
+        pass
+
+    def step(self):
+        pass

internlm/utils/evaluation.py

@@ -50,112 +50,6 @@ def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape
         trainer.schedule._hooks = prev_metric_hooks
 
 
-def evaluate_on_val_dls(
-    trainer,
-    val_dls,
-    writer,
-    logger,
-    step_count,
-    update_panel: bool = False,
-):
-    torch.cuda.empty_cache()
-    trainer.eval()
-    verbose = gpc.is_rank_for_log()
-    data_cfg = gpc.config.data
-
-    for val_name, val_dl in val_dls.items():
-        if len(val_dl) == 0 and verbose:
-            logger.info(f"Validation dataset: {val_name} is empty")
-            continue
-
-        val_metric = AccPerplex(
-            device=torch.cuda.current_device(),
-            tp_pg=gpc.get_group(ParallelMode.TENSOR),
-            dp_pg=gpc.get_group(ParallelMode.DATA),
-        )
-        val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
-
-        val_loss = 0
-        val_idx = -1
-        for val_idx, batch in tqdm(
-            enumerate(val_dl),
-            desc="Val.",
-            total=len(val_dl),
-            position=1,
-            disable=not verbose,
-            leave=False,
-        ):
-            with torch.inference_mode():
-                if gpc.is_using_pp():
-                    total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
-
-                    with switch_evaluation_pipeline_scheduler(
-                        trainer=trainer,
-                        num_microbatches=num_microbatches,
-                        tensor_shape=tensor_shape,
-                        metric_hook_list=[val_sche_metric_hook],
-                    ):
-                        _, _, loss = trainer.execute_schedule(
-                            batch, forward_only=True, return_loss=True, return_output_label=False
-                        )
-                else:
-                    total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
-                    grad_accum_batch_size = data_cfg.micro_bsz
-                    with switch_evaluation_no_pipeline_scheduler(
-                        trainer=trainer,
-                        grad_accum_size=grad_accum_size,
-                        grad_accum_batch_size=grad_accum_batch_size,
-                        metric_hook_list=[val_sche_metric_hook],
-                    ):
-                        _, _, loss = trainer.execute_schedule(
-                            batch, forward_only=True, return_loss=True, return_output_label=False
-                        )
-            if verbose:
-                val_loss += loss.item()
-
-        assert val_idx != -1
-        dist.barrier()
-
-        val_res = val_metric.get_metric()
-        if verbose and len(val_dl) != 0:
-            val_loss = val_loss / (val_idx + 1 + 1e-6)
-            infos = {
-                "step": step_count,
-                f"val/{val_name}_loss": val_loss,
-                f"val/{val_name}_acc": val_res["acc"],
-                f"val/{val_name}_plex": val_res["perplexity"],
-            }
-
-            for key, value in infos.items():
-                writer.add_scalar(key=key, value=value, step=step_count)
-
-            if update_panel:
-                logger.info(
-                    f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
-                    extra={
-                        "step": step_count,
-                        "val_loss": val_loss,
-                        "val_acc": val_res["acc"],
-                        "val_perplexity": val_res["perplexity"],
-                    },
-                )
-            else:
-                logger.info(
-                    f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
-                )
-
-    trainer.train()
-    torch.cuda.empty_cache()
-    dist.barrier()
-
-
 @contextmanager
 def switch_sequence_parallel_mode():
     prev_mode = gpc.config.model.sequence_parallel

@@ -164,3 +58,110 @@ def switch_sequence_parallel_mode():
         yield
     finally:
         gpc.config.model.sequence_parallel = prev_mode
+
+
+def evaluate_on_val_dls(
+    trainer,
+    val_dls,
+    writer,
+    logger,
+    step_count,
+    update_panel: bool = False,
+):
+    with switch_sequence_parallel_mode():
+        torch.cuda.empty_cache()
+        trainer.eval()
+        verbose = gpc.is_rank_for_log()
+        data_cfg = gpc.config.data
+
+        for val_name, val_dl in val_dls.items():
+            if len(val_dl) == 0 and verbose:
+                logger.info(f"Validation dataset: {val_name} is empty")
+                continue
+
+            val_metric = AccPerplex(
+                device=torch.cuda.current_device(),
+                tp_pg=gpc.get_group(ParallelMode.TENSOR),
+                dp_pg=gpc.get_group(ParallelMode.DATA),
+            )
+            val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
+
+            val_loss = 0
+            val_idx = -1
+            for val_idx, batch in tqdm(
+                enumerate(val_dl),
+                desc="Val.",
+                total=len(val_dl),
+                position=1,
+                disable=not verbose,
+                leave=False,
+            ):
+                with torch.inference_mode():
+                    if gpc.is_using_pp():
+                        total_val_bsz = len(batch[1])
+                        assert total_val_bsz % data_cfg.micro_bsz == 0
+                        num_microbatches = total_val_bsz // data_cfg.micro_bsz
+                        tensor_shape = torch.Size(
+                            [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
+                        )
+
+                        with switch_evaluation_pipeline_scheduler(
+                            trainer=trainer,
+                            num_microbatches=num_microbatches,
+                            tensor_shape=tensor_shape,
+                            metric_hook_list=[val_sche_metric_hook],
+                        ):
+                            _, _, loss = trainer.execute_schedule(
+                                batch, forward_only=True, return_loss=True, return_output_label=False
+                            )
+                    else:
+                        total_val_bsz = len(batch[1])
+                        assert total_val_bsz % data_cfg.micro_bsz == 0
+                        grad_accum_size = total_val_bsz // data_cfg.micro_bsz
+                        grad_accum_batch_size = data_cfg.micro_bsz
+                        with switch_evaluation_no_pipeline_scheduler(
+                            trainer=trainer,
+                            grad_accum_size=grad_accum_size,
+                            grad_accum_batch_size=grad_accum_batch_size,
+                            metric_hook_list=[val_sche_metric_hook],
+                        ):
+                            _, _, loss = trainer.execute_schedule(
+                                batch, forward_only=True, return_loss=True, return_output_label=False
+                            )
+                if verbose:
+                    val_loss += loss.item()
+
+            assert val_idx != -1
+            dist.barrier()
+
+            val_res = val_metric.get_metric()
+            if verbose and len(val_dl) != 0:
+                val_loss = val_loss / (val_idx + 1 + 1e-6)
+                infos = {
+                    "step": step_count,
+                    f"val/{val_name}_loss": val_loss,
+                    f"val/{val_name}_acc": val_res["acc"],
+                    f"val/{val_name}_plex": val_res["perplexity"],
+                }
+
+                for key, value in infos.items():
+                    writer.add_scalar(key=key, value=value, step=step_count)
+
+                if update_panel:
+                    logger.info(
+                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()]),
+                        extra={
+                            "step": step_count,
+                            "val_loss": val_loss,
+                            "val_acc": val_res["acc"],
+                            "val_perplexity": val_res["perplexity"],
+                        },
+                    )
+                else:
+                    logger.info(
+                        f"Validation on {val_name}: " + " ".join([f"{key}={value}" for key, value in infos.items()])
+                    )
+
+        trainer.train()
+        torch.cuda.empty_cache()
+        dist.barrier()

train.py (180 changed lines)

@@ -38,12 +38,13 @@ from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.utils.common import (
     BatchSkipper,
+    DummyProfile,
     get_master_node,
     get_megatron_flops,
     launch_time,
     parse_args,
 )
-from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode
+from internlm.utils.evaluation import evaluate_on_val_dls
 from internlm.utils.logger import get_logger, initialize_uniscale_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.model_checkpoint import (

@@ -292,6 +293,11 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState
             train_state.num_consumed_samples_in_epoch = 0
     timer("batch-gen").stop()
 
+    if batch[0].get("type_ids", None) is not None:
+        # if use_flash_attn is False, we need to unpack type_ids
+        if not gpc.config.model.use_flash_attn:
+            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"])
+
     return batch, train_iter
 
 
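
For context on the new type_ids branch: with packed sequences, cu_seqlens holds cumulative sequence lengths, and "unpacking" splits the flat token-level tensor back into one padded row per sequence. The helper below is a hypothetical stand-in written only to illustrate that idea; it is not InternLM's actual unpack_data, whose exact shapes and padding behavior may differ.

    import torch


    def unpack_with_cu_seqlens(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Split a flat [total_tokens] tensor into a padded [num_seqs, max_len] tensor (assumed semantics)."""
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        max_len = max(lengths)
        rows = []
        for start, length in zip(cu_seqlens[:-1].tolist(), lengths):
            row = packed[start : start + length]
            rows.append(torch.nn.functional.pad(row, (0, max_len - length)))
        return torch.stack(rows)


    # Example: two packed sequences of lengths 3 and 2.
    type_ids = torch.tensor([0, 0, 1, 2, 2])
    cu_seqlens = torch.tensor([0, 3, 5])
    print(unpack_with_cu_seqlens(type_ids, cu_seqlens))  # -> [[0, 0, 1], [2, 2, 0]]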

@@ -323,6 +329,29 @@ def initialize_optimizer(model: nn.Module):
     return optimizer, beta2_scheduler, lr_scheduler
 
 
+def initialize_llm_profile(profiling: bool = False, start_time: str = None):
+    """Initialize and return the profiler context manager instance."""
+
+    if profiling and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+        llm_profile = torch.profiler.profile
+        logger.info(f"Do profiling in rank {gpc.get_global_rank()}!")
+    else:
+        llm_profile = DummyProfile
+
+    return llm_profile(
+        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+        schedule=torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(
+            f"{gpc.config.JOB_NAME}/{start_time}/traces/rank{gpc.get_global_rank()}_"
+            + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_"
+            + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
+            + f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}",
+        ),
+        with_stack=True,
+        with_modules=True,
+    )
+
+
 def record_current_batch_training_metrics(
     get_tflops_func,
     logger,
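
The schedule above captures the trace on a single early iteration: five steps are skipped outright, then one wait step, one warmup step, and one active step, with no repetition. A quick way to sanity-check which step actually gets recorded (assumes only that torch is installed; the printed names come from torch.profiler.ProfilerAction):

    import torch

    sched = torch.profiler.schedule(skip_first=5, wait=1, warmup=1, active=1, repeat=1)
    for step in range(10):
        # Expect NONE for steps 0-5, WARMUP at step 6, RECORD_AND_SAVE at step 7, NONE afterwards.
        print(step, sched(step))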

@@ -587,80 +616,79 @@ def main(args):
     # transfer the train data loader into train data iterator
     train_iter = iter(train_dl)
 
-    # start iterating the train data and begin training
-    for batch_count in range(train_state.batch_count, total_steps):
-        if batch_count % 50 == 0:
-            torch.cuda.empty_cache()
-
-        start_time = time.time()
-        timer("one-batch").start()
-
-        # load batch data
-        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
-
-        # record the consumed samples in training
-        train_state.batch_count = batch_count
-        train_state.num_consumed_samples_in_epoch += len(batch[1])
-        if batch_skipper(batch_count):  # skip this batch
-            if gpc.is_rank_for_log():
-                logger.info(f"Skip batch count:`{batch_count}`...")
-            timer("one-batch").stop()
-            continue
-
-        # zero the grads of parameters
-        trainer.zero_grad()
-        type_ids = batch[0].pop("type_ids", None)
-        # process data
-        # if use_flash_attn is False, we need to unpack type_ids
-        if not gpc.config.model.use_flash_attn:
-            type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
-        if type_ids is not None:
-            metric.set_current_type_ids(type_ids=type_ids)
-
-        # do forward and backward
-        timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
-        timer("fwd-bwd").stop()
-
-        # update parameters, and returns (success_update, grad_norm)
-        trainer_result = trainer.step()
-        assert trainer_result is not None
-
-        success_update, grad_norm_groups = trainer_result
-        if success_update:  # update parameters successfully
-            train_state.step_count += 1
-        else:
-            train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
-                logger.warning(f"Warning: skip parameter update at step {batch_count}.")
-                send_alert_message(
-                    address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."
-                )
-
-        # calculate and record the training metrics, eg. loss, accuracy and so on.
-        record_current_batch_training_metrics(
-            get_tflops_func=get_tflops_func,
-            logger=logger,
-            writer=writer,
-            success_update=success_update,
-            batch_count=batch_count,
-            batch=batch,
-            train_state=train_state,
-            optimizer=optimizer,
-            beta2_scheduler=beta2_scheduler,
-            trainer=trainer,
-            start_time=start_time,
-            loss=loss,
-            grad_norm=np.array(grad_norm_groups),
-            metric=metric,
-            update_panel=uniscale_logger is not None,
-        )
-
-        timer("one-batch").stop()
-
-        # evaluate on validation data loaders
-        if valid_every > 0 and train_state.step_count % valid_every == 0:
-            with switch_sequence_parallel_mode():
+    with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
+        # start iterating the train data and begin training
+        for batch_count in range(train_state.batch_count, total_steps):
+            if batch_count % 50 == 0:
+                torch.cuda.empty_cache()
+
+            start_time = time.time()
+            timer("one-batch").start()
+
+            # load batch data
+            batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
+
+            # record the consumed samples in training
+            train_state.batch_count = batch_count
+            train_state.num_consumed_samples_in_epoch += len(batch[1])
+            if batch_skipper(batch_count):  # skip this batch
+                if gpc.is_rank_for_log():
+                    logger.info(f"Skip batch count:`{batch_count}`...")
+                timer("one-batch").stop()
+                continue
+
+            # zero the grads of parameters
+            trainer.zero_grad()
+            # process data
+            if batch[0].get("type_ids", None) is not None:
+                metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))
+
+            # do forward and backward
+            timer("fwd-bwd").start()
+            _, _, loss = trainer.execute_schedule(
+                batch, forward_only=False, return_loss=True, return_output_label=False
+            )
+            timer("fwd-bwd").stop()
+
+            # update parameters, and returns (success_update, grad_norm)
+            trainer_result = trainer.step()
+            assert trainer_result is not None
+
+            success_update, grad_norm_groups = trainer_result
+            if success_update:  # update parameters successfully
+                train_state.step_count += 1
+            else:
+                train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
+                if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+                    logger.warning(f"Warning: skip parameter update at step {batch_count}.")
+                    send_alert_message(
+                        address=gpc.config.alert_address,
+                        message=f"Warning: skip parameter update at step {batch_count}.",
+                    )
+
+            # calculate and record the training metrics, eg. loss, accuracy and so on.
+            record_current_batch_training_metrics(
+                get_tflops_func=get_tflops_func,
+                logger=logger,
+                writer=writer,
+                success_update=success_update,
+                batch_count=batch_count,
+                batch=batch,
+                train_state=train_state,
+                optimizer=optimizer,
+                beta2_scheduler=beta2_scheduler,
+                trainer=trainer,
+                start_time=start_time,
+                loss=loss,
+                grad_norm=np.array(grad_norm_groups),
+                metric=metric,
+                update_panel=uniscale_logger is not None,
+            )
+
+            timer("one-batch").stop()
+
+            # evaluate on validation data loaders
+            if valid_every > 0 and train_state.step_count % valid_every == 0:
                 evaluate_on_val_dls(
                     trainer=trainer,
                     val_dls=val_dls,

@@ -670,12 +698,14 @@ def main(args):
                     update_panel=uniscale_logger is not None,
                 )
 
-        if memory_profiler is not None:
-            memory_profiler.step()
-
-        # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
-        # # save batch sampler that tracks the true consumed samples
-        ckpt_save_manager.try_save_checkpoint(train_state)
+            # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every"
+            # # save batch sampler that tracks the true consumed samples
+            ckpt_save_manager.try_save_checkpoint(train_state)
+
+            if memory_profiler is not None:
+                memory_profiler.step()
+
+            prof.step()
 
     ckpt_save_manager.wait_async_upload_finish()